{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Lesson 1\n",
    "\n",
    "褚则伟 zeweichu@gmail.com\n",
    "\n",
    "[Reference](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html)"
   ]
  },
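  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "What is PyTorch?\n",
    "================\n",
    "\n",
    "PyTorch is a Python-based scientific computing library with the following features:\n",
    "\n",
    "- It is similar to NumPy, but it can run on GPUs.\n",
    "- It can be used to define deep learning models and to train and run them flexibly.\n",
    "\n",
    "Tensors\n",
    "---------------\n",
    "\n",
    "Tensors are similar to NumPy's ndarrays; the key difference is that Tensors can also run on a GPU to accelerate computation.\n"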
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造一个未初始化的5x3矩阵:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "        [0.0000e+00, 4.7339e+30, 1.4347e-19],\n",
      "        [2.7909e+23, 1.8037e+28, 1.7237e+25],\n",
      "        [9.1041e-12, 6.2609e+22, 4.7428e+30],\n",
      "        [3.8001e-39, 0.0000e+00, 0.0000e+00]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.empty(5, 3)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建一个随机初始化的矩阵:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.4821, 0.3854, 0.8517],\n",
      "        [0.7962, 0.0632, 0.5409],\n",
      "        [0.8891, 0.6112, 0.7829],\n",
      "        [0.0715, 0.8069, 0.2608],\n",
      "        [0.3292, 0.0119, 0.2759]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.rand(5, 3)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建一个全部为0，类型为long的矩阵:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 0, 0],\n",
      "        [0, 0, 0],\n",
      "        [0, 0, 0],\n",
      "        [0, 0, 0],\n",
      "        [0, 0, 0]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.zeros(5, 3, dtype=torch.long)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从数据直接直接构建tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.5000, 3.0000])\n"
     ]
    }
   ],
   "source": [
    "x = torch.tensor([5.5, 3])\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也可以从一个已有的tensor构建一个tensor。这些方法会重用原来tensor的特征，例如，数据类型，除非提供新的数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.]], dtype=torch.float64)\n",
      "tensor([[ 1.4793, -2.4772,  0.9738],\n",
      "        [ 2.0328,  1.3981,  1.7509],\n",
      "        [-0.7931, -0.0291, -0.6803],\n",
      "        [-1.2944, -0.7352, -0.9346],\n",
      "        [ 0.5917, -0.5149, -1.8149]])\n"
     ]
    }
   ],
   "source": [
    "x = x.new_ones(5, 3, dtype=torch.double)      # new_* methods take in sizes\n",
    "print(x)\n",
    "\n",
    "x = torch.randn_like(x, dtype=torch.float)    # override dtype!\n",
    "print(x)                                      # result has the same size"
   ]
  },
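  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of how properties are inherited. By default, ``new_*`` methods reuse the dtype (and device) of the tensor they are called on, and ``*_like`` functions reuse its size. Keyword arguments can override either one. The variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "base = torch.ones(2, 3, dtype=torch.double)\n",
    "\n",
    "# new_* methods inherit dtype and device from `base` by default\n",
    "same_dtype = base.new_zeros(4, 4)\n",
    "print(same_dtype.dtype)  # torch.float64\n",
    "\n",
    "# ...unless a different dtype is passed explicitly\n",
    "as_int = base.new_zeros(4, 4, dtype=torch.int32)\n",
    "print(as_int.dtype)  # torch.int32\n",
    "\n",
    "# *_like functions copy the size of the input; dtype can be overridden too\n",
    "like = torch.zeros_like(base, dtype=torch.float)\n",
    "print(like.size(), like.dtype)  # torch.Size([2, 3]) torch.float32"
   ]
  },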
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "得到tensor的形状:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 3])\n"
     ]
    }
   ],
   "source": [
    "print(x.size())"
   ]
  },
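  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because ``torch.Size`` is a subclass of ``tuple``, you can index it, take its ``len`` and unpack it like any other tuple. ``x.shape`` is an alias of ``x.size()``. In this short sketch, the tensor name ``t`` is illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "t = torch.randn(5, 3)\n",
    "size = t.size()\n",
    "\n",
    "print(isinstance(size, tuple))  # True\n",
    "rows, cols = size  # tuple unpacking\n",
    "print(rows, cols)  # 5 3\n",
    "print(size[0], len(size))  # 5 2\n",
    "print(t.shape == size)  # True"
   ]
  },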
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\"><h4>注意</h4><p>``torch.Size`` 返回的是一个tuple</p></div>\n",
    "\n",
    "Operations\n",
    "\n",
    "\n",
    "有很多种tensor运算。我们先介绍加法运算。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.7113, -1.5490,  1.4009],\n",
      "        [ 2.4590,  1.6504,  2.6889],\n",
      "        [-0.3609,  0.4950, -0.3357],\n",
      "        [-0.5029, -0.3086, -0.1498],\n",
      "        [ 1.2850, -0.3189, -0.8868]])\n"
     ]
    }
   ],
   "source": [
    "y = torch.rand(5, 3)\n",
    "print(x + y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "另一种着加法的写法\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.7113, -1.5490,  1.4009],\n",
      "        [ 2.4590,  1.6504,  2.6889],\n",
      "        [-0.3609,  0.4950, -0.3357],\n",
      "        [-0.5029, -0.3086, -0.1498],\n",
      "        [ 1.2850, -0.3189, -0.8868]])\n"
     ]
    }
   ],
   "source": [
    "print(torch.add(x, y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加法：把输出作为一个变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.7113, -1.5490,  1.4009],\n",
      "        [ 2.4590,  1.6504,  2.6889],\n",
      "        [-0.3609,  0.4950, -0.3357],\n",
      "        [-0.5029, -0.3086, -0.1498],\n",
      "        [ 1.2850, -0.3189, -0.8868]])\n"
     ]
    }
   ],
   "source": [
    "result = torch.empty(5, 3)\n",
    "torch.add(x, y, out=result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "in-place加法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.7113, -1.5490,  1.4009],\n",
      "        [ 2.4590,  1.6504,  2.6889],\n",
      "        [-0.3609,  0.4950, -0.3357],\n",
      "        [-0.5029, -0.3086, -0.1498],\n",
      "        [ 1.2850, -0.3189, -0.8868]])\n"
     ]
    }
   ],
   "source": [
    "# adds x to y\n",
    "y.add_(x)\n",
    "print(y)"
   ]
  },
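  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference concrete, here is a small comparison of the out-of-place ``add`` and the in-place ``add_``. The tensor names are illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "p = torch.ones(2, 2)\n",
    "q = torch.full((2, 2), 3.0)\n",
    "\n",
    "# Out-of-place: returns a new tensor and leaves p unchanged\n",
    "r = p.add(q)\n",
    "print(p[0, 0].item(), r[0, 0].item())  # 1.0 4.0\n",
    "\n",
    "# In-place: writes the result into p and returns p itself\n",
    "s = p.add_(q)\n",
    "print(p[0, 0].item())  # 4.0\n",
    "print(s.data_ptr() == p.data_ptr())  # True: same underlying storage"
   ]
  },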
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\"><h4>注意</h4><p>任何in-place的运算都会以``_``结尾。\n",
    "    举例来说：``x.copy_(y)``, ``x.t_()``, 会改变 ``x``。</p></div>\n",
    "\n",
    "各种类似NumPy的indexing都可以在PyTorch tensor上面使用。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-2.4772,  1.3981, -0.0291, -0.7352, -0.5149])\n"
     ]
    }
   ],
   "source": [
    "print(x[:, 1])"
   ]
  },
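  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are a few more NumPy-style indexing patterns, shown as a sketch on a small integer tensor. These cover row and column selection, slicing, boolean masks, and assignment through a mask:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "m = torch.arange(12).reshape(3, 4)  # rows [0..3], [4..7], [8..11]\n",
    "\n",
    "print(m[1])  # second row: values 4, 5, 6, 7\n",
    "print(m[:, -1])  # last column: values 3, 7, 11\n",
    "print(m[0:2, 1:3])  # 2x2 block: [[1, 2], [5, 6]]\n",
    "print(m[m > 8])  # boolean mask: values 9, 10, 11\n",
    "\n",
    "# Indexing also works on the left-hand side of an assignment\n",
    "m[m % 2 == 0] = 0\n",
    "print(m)  # every even entry is now 0"
   ]
  },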
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Resizing: 如果你希望resize/reshape一个tensor，可以使用``torch.view``："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(4, 4)\n",
    "y = x.view(16)\n",
    "z = x.view(-1, 8)  # the size -1 is inferred from other dimensions\n",
    "print(x.size(), y.size(), z.size())"
   ]
  },
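  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two details about ``view`` are worth knowing. First, the result shares memory with the original tensor. Second, it only works when the requested shape is compatible with the tensor's memory layout. ``reshape`` is the more forgiving alternative. A minimal sketch with illustrative names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "g = torch.randn(4, 4)\n",
    "flat_view = g.view(16)\n",
    "\n",
    "# view() shares storage with the original tensor, so writes show up in both\n",
    "flat_view[0] = 100.0\n",
    "print(g[0, 0].item())  # 100.0\n",
    "\n",
    "# A transposed tensor is not contiguous in memory, so view() cannot flatten it\n",
    "gt = g.t()\n",
    "print(gt.is_contiguous())  # False\n",
    "try:\n",
    "    gt.view(16)\n",
    "except RuntimeError as err:\n",
    "    print('view failed:', type(err).__name__)\n",
    "\n",
    "# reshape() handles both cases, copying the data only when it has to\n",
    "flat = gt.reshape(16)\n",
    "print(flat.size())  # torch.Size([16])"
   ]
  },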
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果你有一个只有一个元素的tensor，使用``.item()``方法可以把里面的value变成Python数值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.4726])\n",
      "0.4726296067237854\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(1)\n",
    "print(x)\n",
    "print(x.item())"
   ]
  },
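  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "``.item()`` only works on one-element tensors. For larger tensors, ``.tolist()`` returns the values as (nested) Python lists. The sketch below catches the error broadly, because the exact exception type is an assumption that may vary between PyTorch versions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "v = torch.tensor([1.5, 2.5])\n",
    "\n",
    "# .tolist() converts a tensor of any size into (nested) Python lists\n",
    "print(v.tolist())  # [1.5, 2.5]\n",
    "\n",
    "# .item() fails on a tensor with more than one element\n",
    "try:\n",
    "    v.item()\n",
    "except (RuntimeError, ValueError) as err:\n",
    "    print('item() failed:', err)\n",
    "\n",
    "print(v[0].item())  # 1.5"
   ]
  },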
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**更多阅读**\n",
    "\n",
    "\n",
    "  各种Tensor operations, 包括transposing, indexing, slicing,\n",
    "  mathematical operations, linear algebra, random numbers在\n",
    "  `<https://pytorch.org/docs/torch>`.\n",
    "\n",
    "Numpy和Tensor之间的转化\n",
    "------------\n",
    "\n",
    "在Torch Tensor和NumPy array之间相互转化非常容易。\n",
    "\n",
    "Torch Tensor和NumPy array会共享内存，所以改变其中一项也会改变另一项。\n",
    "\n",
    "把Torch Tensor转变成NumPy Array\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 1., 1., 1., 1.])\n"
     ]
    }
   ],
   "source": [
    "a = torch.ones(5)\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 1. 1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "b = a.numpy()\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "改变numpy array里面的值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2., 2., 2., 2., 2.])\n",
      "[2. 2. 2. 2. 2.]\n"
     ]
    }
   ],
   "source": [
    "a.add_(1)\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "把NumPy ndarray转成Torch Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2. 2. 2. 2. 2.]\n",
      "tensor([2., 2., 2., 2., 2.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "a = np.ones(5)\n",
    "b = torch.from_numpy(a)\n",
    "np.add(a, 1, out=a)\n",
    "print(a)\n",
    "print(b)"
   ]
  },
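  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you need a Tensor that does not share memory with the source array, ``torch.tensor`` copies the data instead. Here is a minimal comparison with illustrative names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "arr = np.ones(3)\n",
    "shared = torch.from_numpy(arr)  # shares memory with arr\n",
    "copied = torch.tensor(arr)  # makes an independent copy of the data\n",
    "\n",
    "arr[0] = 7.0\n",
    "print(shared[0].item())  # 7.0: sees the change\n",
    "print(copied[0].item())  # 1.0: unaffected"
   ]
  },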
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所有CPU上的Tensor都支持转成numpy或者从numpy转成Tensor。\n",
    "\n",
    "CUDA Tensors\n",
    "------------\n",
    "\n",
    "使用``.to``方法，Tensor可以被移动到别的device上。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# let us run this cell only if CUDA is available\n",
    "# We will use ``torch.device`` objects to move tensors in and out of GPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")          # a CUDA device object\n",
    "    y = torch.ones_like(x, device=device)  # directly create a tensor on GPU\n",
    "    x = x.to(device)                       # or just use strings ``.to(\"cuda\")``\n",
    "    z = x + y\n",
    "    print(z)\n",
    "    print(z.to(\"cpu\", torch.double))       # ``.to`` can also change dtype together!"
   ]
  },
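  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only does something when CUDA is available. A common device-agnostic pattern, sketched here and not part of the original lesson, picks the device once so the same code runs on either CPU or GPU. The tensor names are illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Use the GPU if one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "u = torch.randn(3, 3)\n",
    "w = torch.ones_like(u, device=device)  # created directly on device\n",
    "total = u.to(device) + w  # both operands must live on the same device\n",
    "\n",
    "# .numpy() needs a CPU tensor, so move the result back first\n",
    "total_np = total.cpu().numpy()\n",
    "print(device, total_np.shape)"
   ]
  },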
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "热身: 用numpy实现两层神经网络\n",
    "--------------\n",
    "\n",
    "一个全连接ReLU神经网络，一个隐藏层，没有bias。用来从x预测y，使用L2 Loss。\n",
    "\n",
    "这一实现完全使用numpy来计算前向神经网络，loss，和反向传播。\n",
    "\n",
    "numpy ndarray是一个普通的n维array。它不知道任何关于深度学习或者梯度(gradient)的知识，也不知道计算图(computation graph)，只是一种用来计算数学运算的数据结构。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 34399246.46047344\n",
      "1 29023199.257758312\n",
      "2 25155679.85447208\n",
      "3 20344203.603057466\n",
      "4 14771404.625789404\n",
      "5 9796072.99431371\n",
      "6 6194144.749997159\n",
      "7 3948427.3657580013\n",
      "8 2637928.1726997104\n",
      "9 1879876.2597949505\n",
      "10 1424349.925182723\n",
      "11 1131684.579785501\n",
      "12 930879.9521737935\n",
      "13 783503.167740541\n",
      "14 669981.8287784329\n",
      "15 579151.6288421676\n",
      "16 504610.5781504087\n",
      "17 442295.18952143926\n",
      "18 389647.44224490353\n",
      "19 344718.3535892912\n",
      "20 306120.2245707266\n",
      "21 272728.24885829526\n",
      "22 243778.8617292929\n",
      "23 218485.92082002352\n",
      "24 196304.70602822883\n",
      "25 176774.2980280186\n",
      "26 159509.34934842546\n",
      "27 144200.52956072442\n",
      "28 130597.06878493169\n",
      "29 118484.47548850597\n",
      "30 107661.24303895692\n",
      "31 97973.75762285746\n",
      "32 89291.0096051952\n",
      "33 81500.46898789635\n",
      "34 74477.4654945682\n",
      "35 68139.90452489533\n",
      "36 62418.87519034026\n",
      "37 57241.53801123622\n",
      "38 52545.34658231941\n",
      "39 48280.5552386464\n",
      "40 44399.73653914068\n",
      "41 40864.495617471934\n",
      "42 37640.08489317873\n",
      "43 34695.77852549495\n",
      "44 32004.894008637555\n",
      "45 29545.09481447049\n",
      "46 27292.93700341219\n",
      "47 25232.87780747312\n",
      "48 23342.570881009553\n",
      "49 21606.76105421809\n",
      "50 20015.62357395961\n",
      "51 18551.83281521863\n",
      "52 17204.31407669751\n",
      "53 15962.736948706759\n",
      "54 14818.242254751764\n",
      "55 13762.251705340486\n",
      "56 12787.060032590252\n",
      "57 11885.95797873141\n",
      "58 11053.123737613136\n",
      "59 10282.617503272711\n",
      "60 9569.805676161515\n",
      "61 8909.467534754986\n",
      "62 8297.782408129178\n",
      "63 7731.121277369748\n",
      "64 7205.863671952578\n",
      "65 6718.146999962471\n",
      "66 6265.473531640673\n",
      "67 5845.100232214373\n",
      "68 5454.557838660972\n",
      "69 5091.658572234415\n",
      "70 4754.393958028546\n",
      "71 4440.682575260731\n",
      "72 4148.70793229529\n",
      "73 3877.022931816484\n",
      "74 3624.088506535617\n",
      "75 3388.5746286682042\n",
      "76 3169.088547995476\n",
      "77 2964.637382505168\n",
      "78 2774.073275503305\n",
      "79 2596.433385302534\n",
      "80 2430.76267026859\n",
      "81 2276.0929913609607\n",
      "82 2131.752323451521\n",
      "83 1997.0334011418258\n",
      "84 1871.251515936368\n",
      "85 1753.7448614349362\n",
      "86 1643.919519574932\n",
      "87 1541.289735192464\n",
      "88 1445.3733798948592\n",
      "89 1355.6688030350501\n",
      "90 1271.7809967407718\n",
      "91 1193.2972539295215\n",
      "92 1119.8689894828083\n",
      "93 1051.1890596219616\n",
      "94 986.9044505648076\n",
      "95 926.7286776893059\n",
      "96 870.3673474483486\n",
      "97 817.5707566117906\n",
      "98 768.1200077715573\n",
      "99 721.7693127164074\n",
      "100 678.327084576388\n",
      "101 637.5984844921132\n",
      "102 599.3471700265131\n",
      "103 563.480144773489\n",
      "104 529.8443636950776\n",
      "105 498.2900261297218\n",
      "106 468.69076555164696\n",
      "107 440.9141759159077\n",
      "108 414.8406348102356\n",
      "109 390.36201589159975\n",
      "110 367.377986904459\n",
      "111 345.794153113363\n",
      "112 325.53293001498525\n",
      "113 306.50254567681907\n",
      "114 288.6084052220689\n",
      "115 271.79635277136003\n",
      "116 255.99599171996437\n",
      "117 241.14900382305748\n",
      "118 227.1967308563582\n",
      "119 214.07331707113855\n",
      "120 201.72960304299005\n",
      "121 190.1215354419851\n",
      "122 179.20332461067048\n",
      "123 168.93031453109492\n",
      "124 159.26460190296683\n",
      "125 150.18028740393805\n",
      "126 141.6310351604903\n",
      "127 133.5845565128207\n",
      "128 126.00708973959819\n",
      "129 118.87233941614235\n",
      "130 112.15340039819878\n",
      "131 105.82589029792051\n",
      "132 99.86499912782936\n",
      "133 94.24914703127945\n",
      "134 88.95813583258435\n",
      "135 83.97204514587689\n",
      "136 79.27316202972057\n",
      "137 74.84491080200985\n",
      "138 70.66991561621043\n",
      "139 66.7338564785546\n",
      "140 63.02257660379944\n",
      "141 59.52313997962988\n",
      "142 56.22322968024125\n",
      "143 53.111763701729714\n",
      "144 50.175933060002905\n",
      "145 47.40606340622237\n",
      "146 44.793256923660664\n",
      "147 42.328047476976025\n",
      "148 40.00144856286997\n",
      "149 37.80536851048289\n",
      "150 35.73343107253782\n",
      "151 33.77738183530186\n",
      "152 31.930774425664392\n",
      "153 30.187309589532056\n",
      "154 28.541632056826323\n",
      "155 26.987624733348596\n",
      "156 25.52056466134328\n",
      "157 24.135140772349633\n",
      "158 22.82592867935182\n",
      "159 21.589141054848028\n",
      "160 20.42054265142189\n",
      "161 19.31668484832083\n",
      "162 18.27373881475532\n",
      "163 17.288158658562804\n",
      "164 16.35722592792538\n",
      "165 15.47708242556972\n",
      "166 14.64564401606641\n",
      "167 13.859533873388633\n",
      "168 13.116375098259695\n",
      "169 12.413894134407165\n",
      "170 11.749589132060152\n",
      "171 11.121654032769754\n",
      "172 10.527755403701851\n",
      "173 9.966244729776452\n",
      "174 9.43532189744742\n",
      "175 8.933071972010405\n",
      "176 8.458011087399335\n",
      "177 8.008768655278885\n",
      "178 7.583910598664408\n",
      "179 7.181953229124071\n",
      "180 6.801495862472622\n",
      "181 6.441562835264972\n",
      "182 6.100994571216053\n",
      "183 5.778713643608878\n",
      "184 5.473781077364653\n",
      "185 5.185201108079436\n",
      "186 4.910778594973227\n",
      "187 4.651157323562962\n",
      "188 4.405448410225631\n",
      "189 4.172946150445695\n",
      "190 3.953151092232221\n",
      "191 3.744908981306408\n",
      "192 3.5478244092334004\n",
      "193 3.3613194717897596\n",
      "194 3.1847221013676847\n",
      "195 3.017529877609752\n",
      "196 2.859281816477389\n",
      "197 2.7094552588397214\n",
      "198 2.567567563134434\n",
      "199 2.4332311668647577\n",
      "200 2.3060182574075156\n",
      "201 2.1855656999370474\n",
      "202 2.0714842934675706\n",
      "203 1.963429527723876\n",
      "204 1.8610826001126994\n",
      "205 1.7641435814092026\n",
      "206 1.6723644556551025\n",
      "207 1.5854221408589446\n",
      "208 1.5030539196422792\n",
      "209 1.425000717598524\n",
      "210 1.351040918987503\n",
      "211 1.2809841760961524\n",
      "212 1.2146023546998657\n",
      "213 1.151699249786657\n",
      "214 1.09209327048869\n",
      "215 1.035622604251914\n",
      "216 0.9821312023078086\n",
      "217 0.931398400382507\n",
      "218 0.8833282657308957\n",
      "219 0.8377648680868359\n",
      "220 0.7946102894611751\n",
      "221 0.7536759265506683\n",
      "222 0.7148752979674705\n",
      "223 0.6781061363107517\n",
      "224 0.6432513423030822\n",
      "225 0.6102056251947237\n",
      "226 0.5788694783475372\n",
      "227 0.5491564979966296\n",
      "228 0.5209877387920085\n",
      "229 0.49428305592302185\n",
      "230 0.4689618516832248\n",
      "231 0.4449467772589872\n",
      "232 0.4221771538530943\n",
      "233 0.4005837129745524\n",
      "234 0.3801063836563353\n",
      "235 0.3606948589685684\n",
      "236 0.34227967093321354\n",
      "237 0.3248108971814142\n",
      "238 0.3082453595318127\n",
      "239 0.29253330742340067\n",
      "240 0.27763364296473714\n",
      "241 0.26349628059737484\n",
      "242 0.25008661105388497\n",
      "243 0.23736707183615013\n",
      "244 0.2253000957467423\n",
      "245 0.2138523361624442\n",
      "246 0.20298921216044458\n",
      "247 0.19268673137415065\n",
      "248 0.1829142719891026\n",
      "249 0.17364132891347556\n",
      "250 0.1648403028301571\n",
      "251 0.1564884304843211\n",
      "252 0.14856550672688007\n",
      "253 0.14104626799439057\n",
      "254 0.13391149996667578\n",
      "255 0.12714129899063353\n",
      "256 0.12071506432087209\n",
      "257 0.11461697949344261\n",
      "258 0.10882972424002835\n",
      "259 0.10333904654553265\n",
      "260 0.09812616827155932\n",
      "261 0.0931780565261463\n",
      "262 0.08848217530870164\n",
      "263 0.08402546175986916\n",
      "264 0.07979485919700613\n",
      "265 0.07577840283923465\n",
      "266 0.07196528340443757\n",
      "267 0.06834633090302783\n",
      "268 0.06491078821045225\n",
      "269 0.06164958715982291\n",
      "270 0.05855323394724137\n",
      "271 0.055613749333235304\n",
      "272 0.05282348135771675\n",
      "273 0.050174186967151216\n",
      "274 0.04765819132181068\n",
      "275 0.045271438294485024\n",
      "276 0.04300319389006948\n",
      "277 0.04085019680263658\n",
      "278 0.03880575227345778\n",
      "279 0.03686467008847949\n",
      "280 0.035020834986065126\n",
      "281 0.033270091765219535\n",
      "282 0.03160778576828356\n",
      "283 0.03002859285883898\n",
      "284 0.028528860354370092\n",
      "285 0.02710476057141562\n",
      "286 0.025752260418639834\n",
      "287 0.024468202569374195\n",
      "288 0.023248340960852463\n",
      "289 0.022089463426076428\n",
      "290 0.020988894619024936\n",
      "291 0.0199437951726752\n",
      "292 0.018950933816160725\n",
      "293 0.01800781808698391\n",
      "294 0.01711204608500793\n",
      "295 0.016261162547766717\n",
      "296 0.015452712469537512\n",
      "297 0.014684969147452612\n",
      "298 0.013955631973752351\n",
      "299 0.013262568423046399\n",
      "300 0.012604847094777376\n",
      "301 0.011979553027742364\n",
      "302 0.011385439315700433\n",
      "303 0.010820996391120769\n",
      "304 0.010284594931599429\n",
      "305 0.009774981235579705\n",
      "306 0.009290757906358697\n",
      "307 0.008830808784232147\n",
      "308 0.008393777417063167\n",
      "309 0.007978536761423337\n",
      "310 0.007583920779683997\n",
      "311 0.007208915598860074\n",
      "312 0.006852552626458435\n",
      "313 0.006513949532633323\n",
      "314 0.006192232365425878\n",
      "315 0.005886426734936253\n",
      "316 0.0055958430902459\n",
      "317 0.005319724577409574\n",
      "318 0.00505742623277029\n",
      "319 0.004808048657916753\n",
      "320 0.004571049645157161\n",
      "321 0.00434578735628591\n",
      "322 0.00413168646938554\n",
      "323 0.003928177611013847\n",
      "324 0.0037348015880361344\n",
      "325 0.0035511516515125325\n",
      "326 0.00337649547244047\n",
      "327 0.003210468217165082\n",
      "328 0.0030526001815046697\n",
      "329 0.0029025877721886788\n",
      "330 0.002759953496700082\n",
      "331 0.0026243799521717052\n",
      "332 0.0024955282428992384\n",
      "333 0.002373029911103254\n",
      "334 0.0022565698554979346\n",
      "335 0.0021458674153799783\n",
      "336 0.0020406457772717784\n",
      "337 0.0019405998694530123\n",
      "338 0.0018454766429137211\n",
      "339 0.001755053210723282\n",
      "340 0.0016691117022354827\n",
      "341 0.0015873764219450975\n",
      "342 0.001509661141624596\n",
      "343 0.0014357855252587638\n",
      "344 0.0013655404837908268\n",
      "345 0.0012987690170491828\n",
      "346 0.0012352672659498882\n",
      "347 0.0011748904401224442\n",
      "348 0.0011174828717238274\n",
      "349 0.001062929250585763\n",
      "350 0.0010110289742089754\n",
      "351 0.0009616701215304243\n",
      "352 0.0009147416941442063\n",
      "353 0.0008701229479127731\n",
      "354 0.0008276969799764454\n",
      "355 0.0007873380259436494\n",
      "356 0.0007489547972389264\n",
      "357 0.000712451804667229\n",
      "358 0.000677739371845171\n",
      "359 0.000644728435464334\n",
      "360 0.0006133353967986715\n",
      "361 0.000583485133709109\n",
      "362 0.0005550950412157807\n",
      "363 0.0005280932731431832\n",
      "364 0.0005024110280030845\n",
      "365 0.00047798591228805434\n",
      "366 0.0004547507774551706\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "367 0.0004326546423136559\n",
      "368 0.0004116382458083261\n",
      "369 0.0003916440959886334\n",
      "370 0.0003726296356534275\n",
      "371 0.0003545443586216977\n",
      "372 0.000337347352488608\n",
      "373 0.00032099061370803334\n",
      "374 0.0003054229784132819\n",
      "375 0.00029061647064382485\n",
      "376 0.0002765299098361774\n",
      "377 0.0002631327221101076\n",
      "378 0.0002503865963973947\n",
      "379 0.0002382599294869431\n",
      "380 0.00022672670184804494\n",
      "381 0.00021575299560298047\n",
      "382 0.00020531375263207438\n",
      "383 0.000195381616896771\n",
      "384 0.00018593500698085453\n",
      "385 0.00017694494225329907\n",
      "386 0.00016839225855899982\n",
      "387 0.00016025517275686525\n",
      "388 0.00015251350815142156\n",
      "389 0.0001451491411549753\n",
      "390 0.0001381428245892601\n",
      "391 0.00013147417414693054\n",
      "392 0.00012512977608770297\n",
      "393 0.00011909308605343111\n",
      "394 0.00011334857979979945\n",
      "395 0.00010788480695473414\n",
      "396 0.00010268704883570024\n",
      "397 9.773868892276339e-05\n",
      "398 9.303020197524704e-05\n",
      "399 8.85491663624475e-05\n",
      "400 8.428485316645869e-05\n",
      "401 8.022778747190388e-05\n",
      "402 7.636668153099922e-05\n",
      "403 7.269236014951034e-05\n",
      "404 6.919607836124983e-05\n",
      "405 6.586822433827189e-05\n",
      "406 6.270255584885866e-05\n",
      "407 5.968876673919747e-05\n",
      "408 5.682053107105221e-05\n",
      "409 5.409095915616159e-05\n",
      "410 5.1493171389161614e-05\n",
      "411 4.902052422909128e-05\n",
      "412 4.666751255362057e-05\n",
      "413 4.442778311925004e-05\n",
      "414 4.229665906499858e-05\n",
      "415 4.026848185397114e-05\n",
      "416 3.833731276858471e-05\n",
      "417 3.6499489902161296e-05\n",
      "418 3.47508278392673e-05\n",
      "419 3.308721418833935e-05\n",
      "420 3.150227517115802e-05\n",
      "421 2.9993581650856873e-05\n",
      "422 2.855746554164107e-05\n",
      "423 2.719082047553995e-05\n",
      "424 2.5889677603252346e-05\n",
      "425 2.4650964056310747e-05\n",
      "426 2.3472062284026367e-05\n",
      "427 2.234982200401504e-05\n",
      "428 2.1281362386681205e-05\n",
      "429 2.026427108363315e-05\n",
      "430 1.9295935390705645e-05\n",
      "431 1.837421991578651e-05\n",
      "432 1.7497081369101593e-05\n",
      "433 1.6661834251218388e-05\n",
      "434 1.586646319751679e-05\n",
      "435 1.5109303894458321e-05\n",
      "436 1.4388405284291706e-05\n",
      "437 1.370208486427819e-05\n",
      "438 1.3048788200557073e-05\n",
      "439 1.2426643170448882e-05\n",
      "440 1.1834592393040145e-05\n",
      "441 1.1270674124878297e-05\n",
      "442 1.0733978520453249e-05\n",
      "443 1.0222848634947958e-05\n",
      "444 9.736143369136732e-06\n",
      "445 9.27267469036151e-06\n",
      "446 8.831344021766777e-06\n",
      "447 8.411269141029111e-06\n",
      "448 8.01120400023348e-06\n",
      "449 7.630301830674667e-06\n",
      "450 7.267514155906956e-06\n",
      "451 6.922046453620718e-06\n",
      "452 6.593111073936982e-06\n",
      "453 6.279880429162564e-06\n",
      "454 5.981637837795765e-06\n",
      "455 5.697647106767495e-06\n",
      "456 5.427145298606709e-06\n",
      "457 5.169599426636958e-06\n",
      "458 4.924305367745114e-06\n",
      "459 4.690669682311872e-06\n",
      "460 4.468174841019597e-06\n",
      "461 4.256343870589204e-06\n",
      "462 4.054581997203786e-06\n",
      "463 3.862428795746305e-06\n",
      "464 3.67941417005127e-06\n",
      "465 3.50524041752366e-06\n",
      "466 3.339264640996386e-06\n",
      "467 3.1811641594010322e-06\n",
      "468 3.0305596775569415e-06\n",
      "469 2.8871296080016895e-06\n",
      "470 2.7505179893217363e-06\n",
      "471 2.6204108493977605e-06\n",
      "472 2.496484420691004e-06\n",
      "473 2.3784400359079537e-06\n",
      "474 2.266014000554243e-06\n",
      "475 2.1589189914752615e-06\n",
      "476 2.0569210781071096e-06\n",
      "477 1.959750825498873e-06\n",
      "478 1.867203798935969e-06\n",
      "479 1.7790548691072608e-06\n",
      "480 1.6950719025924979e-06\n",
      "481 1.6150665241638997e-06\n",
      "482 1.5388752061694276e-06\n",
      "483 1.4662834016989446e-06\n",
      "484 1.3971344832895556e-06\n",
      "485 1.3312570753140638e-06\n",
      "486 1.2684946752376657e-06\n",
      "487 1.2087236535163552e-06\n",
      "488 1.1517952524068044e-06\n",
      "489 1.0975341827852709e-06\n",
      "490 1.0458478627989778e-06\n",
      "491 9.966057254836819e-07\n",
      "492 9.49690406945525e-07\n",
      "493 9.04995815168244e-07\n",
      "494 8.624191220382796e-07\n",
      "495 8.218529042471487e-07\n",
      "496 7.831982460890191e-07\n",
      "497 7.463699242750524e-07\n",
      "498 7.112838972272693e-07\n",
      "499 6.778580009634641e-07\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# N is batch size; D_in is input dimension;\n",
    "# H is hidden dimension; D_out is output dimension.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# Create random input and output data\n",
    "x = np.random.randn(N, D_in)\n",
    "y = np.random.randn(N, D_out)\n",
    "\n",
    "# Randomly initialize weights\n",
    "w1 = np.random.randn(D_in, H)\n",
    "w2 = np.random.randn(H, D_out)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    # Forward pass: compute predicted y\n",
    "    h = x.dot(w1)\n",
    "    h_relu = np.maximum(h, 0)\n",
    "    y_pred = h_relu.dot(w2)\n",
    "\n",
    "    # Compute and print loss\n",
    "    loss = np.square(y_pred - y).sum()\n",
    "    print(t, loss)\n",
    "\n",
    "    # Backprop to compute gradients of w1 and w2 with respect to loss\n",
    "    # loss = sum((y_pred - y) ** 2), so d(loss)/d(y_pred) = 2 * (y_pred - y)\n",
    "    grad_y_pred = 2.0 * (y_pred - y)\n",
    "    # y_pred = h_relu.dot(w2)\n",
    "    grad_w2 = h_relu.T.dot(grad_y_pred)\n",
    "    grad_h_relu = grad_y_pred.dot(w2.T)\n",
    "    # ReLU: zero the gradient wherever its input h was negative\n",
    "    grad_h = grad_h_relu.copy()\n",
    "    grad_h[h < 0] = 0\n",
    "    # h = x.dot(w1)\n",
    "    grad_w1 = x.T.dot(grad_h)\n",
    "\n",
    "    # Update weights\n",
    "    w1 -= learning_rate * grad_w1\n",
    "    w2 -= learning_rate * grad_w2"
   ]
  },
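  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hand-written backprop is easy to get subtly wrong. A common sanity check is to compare the analytic gradients with central finite differences. This sketch, which is not part of the original lesson, runs that check on a tiny randomly generated instance of the same network. The sizes, seed, ``eps`` and helper names are illustrative assumptions. Everything lives inside a function so the notebook's ``x``, ``y``, ``w1`` and ``w2`` are left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def gradient_check(seed=0, eps=1e-6):\n",
    "    # A tiny instance of the network above, small enough for finite differences\n",
    "    rng = np.random.default_rng(seed)\n",
    "    N, D_in, H, D_out = 4, 5, 3, 2\n",
    "    x = rng.standard_normal((N, D_in))\n",
    "    y = rng.standard_normal((N, D_out))\n",
    "    w1 = rng.standard_normal((D_in, H))\n",
    "    w2 = rng.standard_normal((H, D_out))\n",
    "\n",
    "    def loss_fn():\n",
    "        h_relu = np.maximum(x.dot(w1), 0)\n",
    "        return np.square(h_relu.dot(w2) - y).sum()\n",
    "\n",
    "    # Analytic gradients, using the same formulas as the training loop\n",
    "    h = x.dot(w1)\n",
    "    h_relu = np.maximum(h, 0)\n",
    "    grad_y_pred = 2.0 * (h_relu.dot(w2) - y)\n",
    "    grad_w2 = h_relu.T.dot(grad_y_pred)\n",
    "    grad_h = grad_y_pred.dot(w2.T)\n",
    "    grad_h[h < 0] = 0\n",
    "    grad_w1 = x.T.dot(grad_h)\n",
    "\n",
    "    def numeric_grad(w):\n",
    "        # Central differences, one entry at a time: (f(w + eps) - f(w - eps)) / (2 * eps)\n",
    "        g = np.zeros_like(w)\n",
    "        for idx in np.ndindex(w.shape):\n",
    "            old = w[idx]\n",
    "            w[idx] = old + eps\n",
    "            plus = loss_fn()\n",
    "            w[idx] = old - eps\n",
    "            minus = loss_fn()\n",
    "            w[idx] = old  # restore the original weight\n",
    "            g[idx] = (plus - minus) / (2 * eps)\n",
    "        return g\n",
    "\n",
    "    err_w1 = np.max(np.abs(numeric_grad(w1) - grad_w1))\n",
    "    err_w2 = np.max(np.abs(numeric_grad(w2) - grad_w2))\n",
    "    return err_w1, err_w2\n",
    "\n",
    "err_w1, err_w2 = gradient_check()\n",
    "print(err_w1, err_w2)  # both should be very small"
   ]
  },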
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "PyTorch: Tensors\n",
    "----------------\n",
    "\n",
    "这次我们使用PyTorch tensors来创建前向神经网络，计算损失，以及反向传播。\n",
    "\n",
    "一个PyTorch Tensor很像一个numpy的ndarray。但是它和numpy ndarray最大的区别是，PyTorch Tensor可以在CPU或者GPU上运算。如果想要在GPU上运算，就需要把Tensor换成cuda类型。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 31704728.0\n",
      "1 25331164.0\n",
      "2 22378086.0\n",
      "3 19262238.0\n",
      "4 15348289.0\n",
      "5 11017595.0\n",
      "6 7356282.0\n",
      "7 4705923.5\n",
      "8 3027346.5\n",
      "9 2012536.375\n",
      "10 1409662.25\n",
      "11 1041771.75\n",
      "12 807321.0625\n",
      "13 649262.0\n",
      "14 536533.1875\n",
      "15 451980.875\n",
      "16 385983.53125\n",
      "17 332925.53125\n",
      "18 289368.1875\n",
      "19 253030.78125\n",
      "20 222354.703125\n",
      "21 196214.3125\n",
      "22 173766.515625\n",
      "23 154378.140625\n",
      "24 137539.375\n",
      "25 122867.1015625\n",
      "26 110037.3515625\n",
      "27 98769.4921875\n",
      "28 88842.109375\n",
      "29 80063.15625\n",
      "30 72279.015625\n",
      "31 65361.66796875\n",
      "32 59195.42578125\n",
      "33 53687.4453125\n",
      "34 48757.57421875\n",
      "35 44338.4453125\n",
      "36 40370.34765625\n",
      "37 36803.1484375\n",
      "38 33587.4453125\n",
      "39 30684.1640625\n",
      "40 28059.435546875\n",
      "41 25683.255859375\n",
      "42 23528.814453125\n",
      "43 21570.8515625\n",
      "44 19792.4296875\n",
      "45 18175.244140625\n",
      "46 16704.6640625\n",
      "47 15364.2578125\n",
      "48 14141.7509765625\n",
      "49 13026.609375\n",
      "50 12007.3115234375\n",
      "51 11075.3896484375\n",
      "52 10221.8857421875\n",
      "53 9439.876953125\n",
      "54 8722.13671875\n",
      "55 8063.46826171875\n",
      "56 7458.20703125\n",
      "57 6901.8876953125\n",
      "58 6390.34375\n",
      "59 5919.4794921875\n",
      "60 5485.79345703125\n",
      "61 5086.119140625\n",
      "62 4718.2138671875\n",
      "63 4378.970703125\n",
      "64 4065.92578125\n",
      "65 3776.7900390625\n",
      "66 3509.54296875\n",
      "67 3262.43408203125\n",
      "68 3033.942626953125\n",
      "69 2822.52490234375\n",
      "70 2627.182373046875\n",
      "71 2446.365966796875\n",
      "72 2278.8046875\n",
      "73 2123.408447265625\n",
      "74 1979.00146484375\n",
      "75 1845.013427734375\n",
      "76 1720.6822509765625\n",
      "77 1605.2548828125\n",
      "78 1498.001953125\n",
      "79 1398.356201171875\n",
      "80 1305.7220458984375\n",
      "81 1219.5579833984375\n",
      "82 1139.3939208984375\n",
      "83 1064.7841796875\n",
      "84 995.3250732421875\n",
      "85 930.6298217773438\n",
      "86 870.3472900390625\n",
      "87 814.1729125976562\n",
      "88 761.8153686523438\n",
      "89 713.0128784179688\n",
      "90 667.50048828125\n",
      "91 625.0264892578125\n",
      "92 585.3772583007812\n",
      "93 548.3762817382812\n",
      "94 513.8129272460938\n",
      "95 481.5259094238281\n",
      "96 451.376708984375\n",
      "97 423.1982116699219\n",
      "98 396.865234375\n",
      "99 372.23583984375\n",
      "100 349.208984375\n",
      "101 327.65960693359375\n",
      "102 307.49652099609375\n",
      "103 288.6243591308594\n",
      "104 270.9569396972656\n",
      "105 254.41790771484375\n",
      "106 238.9322052001953\n",
      "107 224.42202758789062\n",
      "108 210.82664489746094\n",
      "109 198.08383178710938\n",
      "110 186.14157104492188\n",
      "111 174.94784545898438\n",
      "112 164.45217895507812\n",
      "113 154.6090850830078\n",
      "114 145.38900756835938\n",
      "115 136.7398681640625\n",
      "116 128.62008666992188\n",
      "117 121.001708984375\n",
      "118 113.84794616699219\n",
      "119 107.13176727294922\n",
      "120 100.82424926757812\n",
      "121 94.90043640136719\n",
      "122 89.33421325683594\n",
      "123 84.10637664794922\n",
      "124 79.19412994384766\n",
      "125 74.57848358154297\n",
      "126 70.23960876464844\n",
      "127 66.15946197509766\n",
      "128 62.32460403442383\n",
      "129 58.7183723449707\n",
      "130 55.32723617553711\n",
      "131 52.13628387451172\n",
      "132 49.13447570800781\n",
      "133 46.310585021972656\n",
      "134 43.65383529663086\n",
      "135 41.152828216552734\n",
      "136 38.799072265625\n",
      "137 36.583656311035156\n",
      "138 34.49782943725586\n",
      "139 32.53558349609375\n",
      "140 30.6860294342041\n",
      "141 28.94465446472168\n",
      "142 27.304447174072266\n",
      "143 25.759523391723633\n",
      "144 24.304840087890625\n",
      "145 22.93392562866211\n",
      "146 21.641254425048828\n",
      "147 20.42369842529297\n",
      "148 19.276079177856445\n",
      "149 18.194564819335938\n",
      "150 17.175493240356445\n",
      "151 16.214174270629883\n",
      "152 15.308029174804688\n",
      "153 14.454139709472656\n",
      "154 13.648143768310547\n",
      "155 12.88845157623291\n",
      "156 12.171833038330078\n",
      "157 11.49567699432373\n",
      "158 10.85841178894043\n",
      "159 10.256678581237793\n",
      "160 9.689424514770508\n",
      "161 9.154097557067871\n",
      "162 8.64884090423584\n",
      "163 8.172189712524414\n",
      "164 7.721974849700928\n",
      "165 7.297136306762695\n",
      "166 6.8962836265563965\n",
      "167 6.5177459716796875\n",
      "168 6.160311698913574\n",
      "169 5.822811126708984\n",
      "170 5.5043110847473145\n",
      "171 5.203525066375732\n",
      "172 4.919389724731445\n",
      "173 4.651163101196289\n",
      "174 4.3978190422058105\n",
      "175 4.158350944519043\n",
      "176 3.9322471618652344\n",
      "177 3.718606948852539\n",
      "178 3.516770839691162\n",
      "179 3.3262054920196533\n",
      "180 3.1460940837860107\n",
      "181 2.975762367248535\n",
      "182 2.814879894256592\n",
      "183 2.662900447845459\n",
      "184 2.5192079544067383\n",
      "185 2.3834681510925293\n",
      "186 2.255030393600464\n",
      "187 2.13377046585083\n",
      "188 2.0190846920013428\n",
      "189 1.9105591773986816\n",
      "190 1.807981014251709\n",
      "191 1.7110538482666016\n",
      "192 1.6193859577178955\n",
      "193 1.5326906442642212\n",
      "194 1.4506947994232178\n",
      "195 1.373248815536499\n",
      "196 1.2998838424682617\n",
      "197 1.2304624319076538\n",
      "198 1.1650127172470093\n",
      "199 1.1028441190719604\n",
      "200 1.0442299842834473\n",
      "201 0.9886825084686279\n",
      "202 0.9362077713012695\n",
      "203 0.8864397406578064\n",
      "204 0.8394078016281128\n",
      "205 0.7948980927467346\n",
      "206 0.7528337836265564\n",
      "207 0.7129263281822205\n",
      "208 0.6751680374145508\n",
      "209 0.6395058035850525\n",
      "210 0.6058014035224915\n",
      "211 0.573722243309021\n",
      "212 0.5434805750846863\n",
      "213 0.5148582458496094\n",
      "214 0.48777079582214355\n",
      "215 0.462094783782959\n",
      "216 0.4378334879875183\n",
      "217 0.41474175453186035\n",
      "218 0.3928961455821991\n",
      "219 0.37232136726379395\n",
      "220 0.35279765725135803\n",
      "221 0.3343387842178345\n",
      "222 0.31676602363586426\n",
      "223 0.3001691997051239\n",
      "224 0.2844657897949219\n",
      "225 0.26963645219802856\n",
      "226 0.2555326223373413\n",
      "227 0.24219271540641785\n",
      "228 0.22961300611495972\n",
      "229 0.21758520603179932\n",
      "230 0.20622654259204865\n",
      "231 0.19550156593322754\n",
      "232 0.18533945083618164\n",
      "233 0.17566744983196259\n",
      "234 0.16653285920619965\n",
      "235 0.15787597000598907\n",
      "236 0.14970409870147705\n",
      "237 0.14190873503684998\n",
      "238 0.13456779718399048\n",
      "239 0.12759016454219818\n",
      "240 0.12096268683671951\n",
      "241 0.11470359563827515\n",
      "242 0.1087842658162117\n",
      "243 0.10314527899026871\n",
      "244 0.09780357778072357\n",
      "245 0.09277193248271942\n",
      "246 0.08799058943986893\n",
      "247 0.0834306925535202\n",
      "248 0.07912513613700867\n",
      "249 0.0750374048948288\n",
      "250 0.07118058204650879\n",
      "251 0.06751800328493118\n",
      "252 0.06403960287570953\n",
      "253 0.06074457988142967\n",
      "254 0.05762597173452377\n",
      "255 0.05466882884502411\n",
      "256 0.0518682561814785\n",
      "257 0.04920265078544617\n",
      "258 0.04668186977505684\n",
      "259 0.044272683560848236\n",
      "260 0.042025547474622726\n",
      "261 0.03986666351556778\n",
      "262 0.037817493081092834\n",
      "263 0.03588436171412468\n",
      "264 0.03405837342143059\n",
      "265 0.03232688084244728\n",
      "266 0.030674563720822334\n",
      "267 0.029108863323926926\n",
      "268 0.027641309425234795\n",
      "269 0.026221055537462234\n",
      "270 0.024893635883927345\n",
      "271 0.02361663617193699\n",
      "272 0.022424183785915375\n",
      "273 0.02127956785261631\n",
      "274 0.020195599645376205\n",
      "275 0.019174542278051376\n",
      "276 0.01821214333176613\n",
      "277 0.01728914864361286\n",
      "278 0.016413141041994095\n",
      "279 0.01559178251773119\n",
      "280 0.014809946529567242\n",
      "281 0.014066735282540321\n",
      "282 0.01335320807993412\n",
      "283 0.012697878293693066\n",
      "284 0.012057363986968994\n",
      "285 0.011450453661382198\n",
      "286 0.010880804620683193\n",
      "287 0.01034202054142952\n",
      "288 0.009831000119447708\n",
      "289 0.009346149861812592\n",
      "290 0.008878068067133427\n",
      "291 0.00844407919794321\n",
      "292 0.008024066686630249\n",
      "293 0.007635605521500111\n",
      "294 0.0072587537579238415\n",
      "295 0.0069105857983231544\n",
      "296 0.006573254242539406\n",
      "297 0.006256978493183851\n",
      "298 0.005949943792074919\n",
      "299 0.005672339349985123\n",
      "300 0.005388857331126928\n",
      "301 0.0051320018246769905\n",
      "302 0.004887753631919622\n",
      "303 0.004658843856304884\n",
      "304 0.0044357734732329845\n",
      "305 0.004228176549077034\n",
      "306 0.004027842078357935\n",
      "307 0.003840153571218252\n",
      "308 0.0036594069097191095\n",
      "309 0.003490033093839884\n",
      "310 0.0033292584121227264\n",
      "311 0.0031785538885742426\n",
      "312 0.003031808650121093\n",
      "313 0.002896031131967902\n",
      "314 0.0027637507300823927\n",
      "315 0.002640662482008338\n",
      "316 0.0025206280406564474\n",
      "317 0.0024077417328953743\n",
      "318 0.0022998738568276167\n",
      "319 0.0022006493527442217\n",
      "320 0.002101228339597583\n",
      "321 0.0020116977393627167\n",
      "322 0.0019245940493419766\n",
      "323 0.0018393839709460735\n",
      "324 0.0017626716289669275\n",
      "325 0.001689193886704743\n",
      "326 0.0016162614338099957\n",
      "327 0.0015509161166846752\n",
      "328 0.0014848458813503385\n",
      "329 0.0014197597047314048\n",
      "330 0.0013633108465000987\n",
      "331 0.0013077231124043465\n",
      "332 0.001255737035535276\n",
      "333 0.001203768653795123\n",
      "334 0.0011556316167116165\n",
      "335 0.001109623583033681\n",
      "336 0.0010652983328327537\n",
      "337 0.0010259364498779178\n",
      "338 0.0009847141336649656\n",
      "339 0.0009464982431381941\n",
      "340 0.0009106658981181681\n",
      "341 0.0008753696456551552\n",
      "342 0.0008441813406534493\n",
      "343 0.0008131045615300536\n",
      "344 0.0007834337884560227\n",
      "345 0.0007538509089499712\n",
      "346 0.0007265112362802029\n",
      "347 0.0007019559852778912\n",
      "348 0.0006759824464097619\n",
      "349 0.0006510045495815575\n",
      "350 0.0006284148548729718\n",
      "351 0.0006068204529583454\n",
      "352 0.0005856421194039285\n",
      "353 0.0005672844126820564\n",
      "354 0.0005473798955790699\n",
      "355 0.0005280547775328159\n",
      "356 0.0005113428342156112\n",
      "357 0.0004943141248077154\n",
      "358 0.00047874540905468166\n",
      "359 0.00046168401604518294\n",
      "360 0.000447523663751781\n",
      "361 0.0004326198832131922\n",
      "362 0.00041988512384705245\n",
      "363 0.00040688799344934523\n",
      "364 0.0003942836483474821\n",
      "365 0.000381028454285115\n",
      "366 0.0003701212117448449\n",
      "367 0.00035913955071009696\n",
      "368 0.0003480427258182317\n",
      "369 0.00033798906952142715\n",
      "370 0.00032894761534407735\n",
      "371 0.0003196335455868393\n",
      "372 0.0003099186287727207\n",
      "373 0.0003019550640601665\n",
      "374 0.00029274728149175644\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "375 0.0002844816190190613\n",
      "376 0.00027625024085864425\n",
      "377 0.0002687727683223784\n",
      "378 0.0002608516369946301\n",
      "379 0.00025311342324130237\n",
      "380 0.0002469048195052892\n",
      "381 0.00024049097555689514\n",
      "382 0.0002342124644201249\n",
      "383 0.00022811403323430568\n",
      "384 0.00022231723414734006\n",
      "385 0.0002166029589716345\n",
      "386 0.00021077181736472994\n",
      "387 0.00020510501053649932\n",
      "388 0.00020020001102238894\n",
      "389 0.0001948442222783342\n",
      "390 0.00018990584067068994\n",
      "391 0.00018529882072471082\n",
      "392 0.00018070911755785346\n",
      "393 0.00017650797963142395\n",
      "394 0.00017214834224432707\n",
      "395 0.0001683011942077428\n",
      "396 0.00016451899136882275\n",
      "397 0.00016050187696237117\n",
      "398 0.00015686434926465154\n",
      "399 0.00015321985119953752\n",
      "400 0.0001501761726103723\n",
      "401 0.00014639270375482738\n",
      "402 0.00014274154091253877\n",
      "403 0.0001396275474689901\n",
      "404 0.0001364489580737427\n",
      "405 0.00013346801279112697\n",
      "406 0.00013024920190218836\n",
      "407 0.00012755846546497196\n",
      "408 0.00012532222899608314\n",
      "409 0.0001224723382620141\n",
      "410 0.00011974618973908946\n",
      "411 0.00011740042100427672\n",
      "412 0.00011441943206591532\n",
      "413 0.00011229746451135725\n",
      "414 0.00010995937191182747\n",
      "415 0.00010784588812384754\n",
      "416 0.00010610915342113003\n",
      "417 0.0001038835325744003\n",
      "418 0.00010166718857362866\n",
      "419 9.979418973671272e-05\n",
      "420 9.793229401111603e-05\n",
      "421 9.590695117367432e-05\n",
      "422 9.412408689968288e-05\n",
      "423 9.244915418094024e-05\n",
      "424 9.07004505279474e-05\n",
      "425 8.880807581590489e-05\n",
      "426 8.733373397262767e-05\n",
      "427 8.574980893172324e-05\n",
      "428 8.392545714741573e-05\n",
      "429 8.241042814916e-05\n",
      "430 8.080529369181022e-05\n",
      "431 7.939618808450177e-05\n",
      "432 7.762608584016562e-05\n",
      "433 7.651503256056458e-05\n",
      "434 7.531026494689286e-05\n",
      "435 7.377369183814153e-05\n",
      "436 7.305829058168456e-05\n",
      "437 7.153345359256491e-05\n",
      "438 7.028930122032762e-05\n",
      "439 6.921228487044573e-05\n",
      "440 6.803637370467186e-05\n",
      "441 6.695889896946028e-05\n",
      "442 6.582081550732255e-05\n",
      "443 6.459224823629484e-05\n",
      "444 6.373634096235037e-05\n",
      "445 6.257549830479547e-05\n",
      "446 6.140403274912387e-05\n",
      "447 6.0855145420646295e-05\n",
      "448 5.961475471849553e-05\n",
      "449 5.8864923630608246e-05\n",
      "450 5.767813490820117e-05\n",
      "451 5.6913944717962295e-05\n",
      "452 5.6101172958733514e-05\n",
      "453 5.5124517530202866e-05\n",
      "454 5.3974403272150084e-05\n",
      "455 5.3289859351934865e-05\n",
      "456 5.267193409963511e-05\n",
      "457 5.193992910790257e-05\n",
      "458 5.1057362725259736e-05\n",
      "459 5.001332101528533e-05\n",
      "460 4.934622847940773e-05\n",
      "461 4.873232319368981e-05\n",
      "462 4.794801498064771e-05\n",
      "463 4.7256213292712346e-05\n",
      "464 4.667185203288682e-05\n",
      "465 4.5966684410814196e-05\n",
      "466 4.526913107838482e-05\n",
      "467 4.486504985834472e-05\n",
      "468 4.413699934957549e-05\n",
      "469 4.358588557806797e-05\n",
      "470 4.305447146180086e-05\n",
      "471 4.2635525460354984e-05\n",
      "472 4.186580190435052e-05\n",
      "473 4.1199065890396014e-05\n",
      "474 4.055891258758493e-05\n",
      "475 4.017409810330719e-05\n",
      "476 3.963488052249886e-05\n",
      "477 3.913479667971842e-05\n",
      "478 3.8683563616359606e-05\n",
      "479 3.806965833064169e-05\n",
      "480 3.7681969843106344e-05\n",
      "481 3.737308725249022e-05\n",
      "482 3.669063517008908e-05\n",
      "483 3.630801438703202e-05\n",
      "484 3.584063597372733e-05\n",
      "485 3.539228782756254e-05\n",
      "486 3.4903492633020505e-05\n",
      "487 3.4551478165667504e-05\n",
      "488 3.4130087442463264e-05\n",
      "489 3.377102984813973e-05\n",
      "490 3.320474206702784e-05\n",
      "491 3.283058322267607e-05\n",
      "492 3.254006151109934e-05\n",
      "493 3.1907922675600275e-05\n",
      "494 3.1705902074463665e-05\n",
      "495 3.147914321743883e-05\n",
      "496 3.111985279247165e-05\n",
      "497 3.079450470977463e-05\n",
      "498 3.0240607884479687e-05\n",
      "499 2.9828101105522364e-05\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cpu\")\n",
    "# device = torch.device(\"cuda:0\") # Uncomment this to run on GPU\n",
    "\n",
    "# N is batch size; D_in is input dimension;\n",
    "# H is hidden dimension; D_out is output dimension.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# Create random input and output data\n",
    "x = torch.randn(N, D_in, device=device, dtype=dtype)\n",
    "y = torch.randn(N, D_out, device=device, dtype=dtype)\n",
    "\n",
    "# Randomly initialize weights\n",
    "w1 = torch.randn(D_in, H, device=device, dtype=dtype)\n",
    "w2 = torch.randn(H, D_out, device=device, dtype=dtype)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
    "    # Forward pass: compute predicted y\n",
    "    h = x.mm(w1)\n",
    "    h_relu = h.clamp(min=0)\n",
    "    y_pred = h_relu.mm(w2)\n",
    "\n",
    "    # Compute and print loss\n",
    "    loss = (y_pred - y).pow(2).sum().item()\n",
    "    print(t, loss)\n",
    "\n",
    "    # Backprop to compute gradients of w1 and w2 with respect to loss\n",
    "    grad_y_pred = 2.0 * (y_pred - y)\n",
    "    grad_w2 = h_relu.t().mm(grad_y_pred)\n",
    "    grad_h_relu = grad_y_pred.mm(w2.t())\n",
    "    grad_h = grad_h_relu.clone()\n",
    "    grad_h[h < 0] = 0\n",
    "    grad_w1 = x.t().mm(grad_h)\n",
    "\n",
    "    # Update weights using gradient descent\n",
    "    w1 -= learning_rate * grad_w1\n",
    "    w2 -= learning_rate * grad_w2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
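    "A simple autograd example: for ``y = w * x + b``, calling ``y.backward()`` computes the gradient of ``y`` with respect to every tensor created with ``requires_grad=True``."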
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.)\n",
      "tensor(1.)\n",
      "tensor(1.)\n"
     ]
    }
   ],
   "source": [
    "# Create tensors.\n",
    "x = torch.tensor(1., requires_grad=True)\n",
    "w = torch.tensor(2., requires_grad=True)\n",
    "b = torch.tensor(3., requires_grad=True)\n",
    "\n",
    "# Build a computational graph.\n",
    "y = w * x + b    # y = 2 * x + 3\n",
    "\n",
    "# Compute gradients.\n",
    "y.backward()\n",
    "\n",
    "# Print out the gradients.\n",
    "print(x.grad)    # dy/dx = w = 2\n",
    "print(w.grad)    # dy/dw = x = 1\n",
    "print(b.grad)    # dy/db = 1"
   ]
  },
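  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One autograd detail matters as soon as ``backward()`` runs inside a loop: gradients are accumulated into ``.grad``, not overwritten. The sketch below reuses the same ``y = w * x + b``. The ``_acc`` variable names exist only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "x_acc = torch.tensor(1., requires_grad=True)\n",
    "w_acc = torch.tensor(2., requires_grad=True)\n",
    "b_acc = torch.tensor(3., requires_grad=True)\n",
    "\n",
    "for _ in range(2):\n",
    "    y_acc = w_acc * x_acc + b_acc  # rebuild the graph on every pass\n",
    "    y_acc.backward()               # adds dy/d(param) into each .grad\n",
    "\n",
    "print(w_acc.grad)  # tensor(2.): dy/dw = x = 1, accumulated twice\n",
    "print(x_acc.grad)  # tensor(4.): dy/dx = w = 2, accumulated twice\n",
    "\n",
    "# Reset before the next independent backward pass, as a training loop must\n",
    "w_acc.grad.zero_()\n",
    "print(w_acc.grad)  # tensor(0.)"
   ]
  },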
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
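    "PyTorch: Tensors and autograd\n",
    "-------------------------------\n",
    "\n",
    "A key PyTorch feature is autograd. Once you define the forward pass and compute the loss, PyTorch automatically differentiates to obtain the gradients of all model parameters.\n",
    "\n",
    "A PyTorch Tensor represents a node in the computational graph. Suppose ``x`` is a Tensor with ``x.requires_grad=True``. After ``backward()`` is called on a scalar (usually the loss), ``x.grad`` is another Tensor with the same shape as ``x``. It holds the gradient of that scalar with respect to ``x``.\n"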
   ]
  },
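  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below shows that autograd computes the same gradients that were derived by hand earlier. It runs the two-layer network once with ``requires_grad=True`` weights and calls ``loss.backward()``. It then compares ``w1.grad`` and ``w2.grad`` with the manual formulas. The small sizes, the seed, the tolerances and the helper name ``compare_manual_and_autograd`` are assumptions made for this illustration. It should print ``(True, True)``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def compare_manual_and_autograd(seed=0):\n",
    "    torch.manual_seed(seed)\n",
    "    N, D_in, H, D_out = 8, 20, 10, 3  # small illustrative sizes\n",
    "    x = torch.randn(N, D_in)\n",
    "    y = torch.randn(N, D_out)\n",
    "    w1 = torch.randn(D_in, H, requires_grad=True)\n",
    "    w2 = torch.randn(H, D_out, requires_grad=True)\n",
    "\n",
    "    # Forward pass and loss; autograd records the graph\n",
    "    h = x.mm(w1)\n",
    "    h_relu = h.clamp(min=0)\n",
    "    y_pred = h_relu.mm(w2)\n",
    "    loss = (y_pred - y).pow(2).sum()\n",
    "    loss.backward()  # fills w1.grad and w2.grad\n",
    "\n",
    "    # Manual backward pass from the previous section, without graph tracking\n",
    "    with torch.no_grad():\n",
    "        grad_y_pred = 2.0 * (y_pred - y)\n",
    "        grad_w2 = h_relu.t().mm(grad_y_pred)\n",
    "        grad_h = grad_y_pred.mm(w2.t())\n",
    "        grad_h[h < 0] = 0\n",
    "        grad_w1 = x.t().mm(grad_h)\n",
    "\n",
    "    assert w1.grad.shape == w1.shape  # .grad has the same shape as the tensor\n",
    "    return (torch.allclose(grad_w1, w1.grad, rtol=1e-4, atol=1e-4),\n",
    "            torch.allclose(grad_w2, w2.grad, rtol=1e-4, atol=1e-4))\n",
    "\n",
    "print(compare_manual_and_autograd())"
   ]
  },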
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 31590738.0\n",
      "1 34389704.0\n",
      "2 44504280.0\n",
      "3 52598508.0\n",
      "4 46752264.0\n",
      "5 27227634.0\n",
      "6 10779343.0\n",
      "7 3889138.75\n",
      "8 1856397.875\n",
      "9 1232127.25\n",
      "10 967278.5\n",
      "11 806383.9375\n",
      "12 687169.25\n",
      "13 591936.25\n",
      "14 513579.40625\n",
      "15 448339.5\n",
      "16 393390.71875\n",
      "17 346772.71875\n",
      "18 306952.625\n",
      "19 272743.90625\n",
      "20 243250.578125\n",
      "21 217760.4375\n",
      "22 195513.75\n",
      "23 176012.4375\n",
      "24 158848.59375\n",
      "25 143694.4375\n",
      "26 130272.53125\n",
      "27 118357.1328125\n",
      "28 107732.5625\n",
      "29 98245.9296875\n",
      "30 89754.4375\n",
      "31 82145.9765625\n",
      "32 75299.703125\n",
      "33 69130.7265625\n",
      "34 63549.09375\n",
      "35 58498.18359375\n",
      "36 53914.7421875\n",
      "37 49751.984375\n",
      "38 45963.8515625\n",
      "39 42512.19140625\n",
      "40 39364.1484375\n",
      "41 36486.7421875\n",
      "42 33852.94921875\n",
      "43 31441.951171875\n",
      "44 29230.11328125\n",
      "45 27200.080078125\n",
      "46 25335.595703125\n",
      "47 23618.97265625\n",
      "48 22036.193359375\n",
      "49 20575.412109375\n",
      "50 19227.5078125\n",
      "51 17980.865234375\n",
      "52 16826.919921875\n",
      "53 15756.392578125\n",
      "54 14762.513671875\n",
      "55 13839.58203125\n",
      "56 12981.9228515625\n",
      "57 12184.3896484375\n",
      "58 11442.140625\n",
      "59 10750.8681640625\n",
      "60 10106.751953125\n",
      "61 9505.8720703125\n",
      "62 8944.9736328125\n",
      "63 8420.947265625\n",
      "64 7931.27734375\n",
      "65 7473.25\n",
      "66 7044.6455078125\n",
      "67 6643.35693359375\n",
      "68 6267.51123046875\n",
      "69 5915.16064453125\n",
      "70 5584.55615234375\n",
      "71 5274.3408203125\n",
      "72 4983.08349609375\n",
      "73 4709.5224609375\n",
      "74 4452.46484375\n",
      "75 4210.7333984375\n",
      "76 3983.40234375\n",
      "77 3769.482177734375\n",
      "78 3568.060302734375\n",
      "79 3378.383056640625\n",
      "80 3199.61962890625\n",
      "81 3031.24169921875\n",
      "82 2872.541748046875\n",
      "83 2722.83056640625\n",
      "84 2581.67431640625\n",
      "85 2448.48193359375\n",
      "86 2322.647705078125\n",
      "87 2203.833251953125\n",
      "88 2091.537353515625\n",
      "89 1985.4141845703125\n",
      "90 1885.1256103515625\n",
      "91 1790.2205810546875\n",
      "92 1700.4173583984375\n",
      "93 1615.4560546875\n",
      "94 1535.025390625\n",
      "95 1458.8707275390625\n",
      "96 1386.747314453125\n",
      "97 1318.4639892578125\n",
      "98 1253.7381591796875\n",
      "99 1192.38330078125\n",
      "100 1134.2308349609375\n",
      "101 1079.12255859375\n",
      "102 1026.814697265625\n",
      "103 977.2050170898438\n",
      "104 930.1414184570312\n",
      "105 885.4822998046875\n",
      "106 843.100341796875\n",
      "107 802.8572387695312\n",
      "108 764.6590576171875\n",
      "109 728.3578491210938\n",
      "110 693.8634033203125\n",
      "111 661.0936279296875\n",
      "112 629.9700927734375\n",
      "113 600.3693237304688\n",
      "114 572.22705078125\n",
      "115 545.47802734375\n",
      "116 520.0562744140625\n",
      "117 495.85888671875\n",
      "118 472.8380432128906\n",
      "119 450.9430236816406\n",
      "120 430.1151428222656\n",
      "121 410.2862854003906\n",
      "122 391.4206237792969\n",
      "123 373.4564514160156\n",
      "124 356.3577880859375\n",
      "125 340.0721130371094\n",
      "126 324.5631103515625\n",
      "127 309.78851318359375\n",
      "128 295.7184143066406\n",
      "129 282.31243896484375\n",
      "130 269.5414733886719\n",
      "131 257.3679504394531\n",
      "132 245.76673889160156\n",
      "133 234.71142578125\n",
      "134 224.170166015625\n",
      "135 214.1173553466797\n",
      "136 204.53355407714844\n",
      "137 195.39988708496094\n",
      "138 186.68582153320312\n",
      "139 178.3736114501953\n",
      "140 170.44247436523438\n",
      "141 162.87950134277344\n",
      "142 155.6625518798828\n",
      "143 148.77378845214844\n",
      "144 142.2010955810547\n",
      "145 135.93218994140625\n",
      "146 129.94813537597656\n",
      "147 124.23357391357422\n",
      "148 118.77680969238281\n",
      "149 113.56990051269531\n",
      "150 108.5975341796875\n",
      "151 103.852783203125\n",
      "152 99.31798553466797\n",
      "153 94.98828125\n",
      "154 90.8508071899414\n",
      "155 86.90143585205078\n",
      "156 83.12843322753906\n",
      "157 79.5234603881836\n",
      "158 76.07766723632812\n",
      "159 72.78606414794922\n",
      "160 69.64219665527344\n",
      "161 66.63894653320312\n",
      "162 63.768882751464844\n",
      "163 61.02363204956055\n",
      "164 58.39965057373047\n",
      "165 55.89255142211914\n",
      "166 53.495147705078125\n",
      "167 51.20370864868164\n",
      "168 49.01222229003906\n",
      "169 46.91761016845703\n",
      "170 44.91386795043945\n",
      "171 42.998748779296875\n",
      "172 41.16694641113281\n",
      "173 39.41572189331055\n",
      "174 37.74125289916992\n",
      "175 36.139198303222656\n",
      "176 34.60701370239258\n",
      "177 33.140785217285156\n",
      "178 31.7379093170166\n",
      "179 30.396512985229492\n",
      "180 29.113128662109375\n",
      "181 27.88433837890625\n",
      "182 26.709943771362305\n",
      "183 25.584423065185547\n",
      "184 24.5084285736084\n",
      "185 23.477943420410156\n",
      "186 22.491390228271484\n",
      "187 21.548294067382812\n",
      "188 20.645343780517578\n",
      "189 19.78101348876953\n",
      "190 18.95306396484375\n",
      "191 18.160476684570312\n",
      "192 17.402193069458008\n",
      "193 16.67613410949707\n",
      "194 15.980989456176758\n",
      "195 15.31508731842041\n",
      "196 14.677319526672363\n",
      "197 14.066808700561523\n",
      "198 13.482386589050293\n",
      "199 12.92292594909668\n",
      "200 12.386455535888672\n",
      "201 11.873013496398926\n",
      "202 11.381402969360352\n",
      "203 10.910120964050293\n",
      "204 10.459120750427246\n",
      "205 10.027090072631836\n",
      "206 9.613011360168457\n",
      "207 9.216401100158691\n",
      "208 8.836578369140625\n",
      "209 8.472214698791504\n",
      "210 8.123515129089355\n",
      "211 7.7894086837768555\n",
      "212 7.469310760498047\n",
      "213 7.162317752838135\n",
      "214 6.868289470672607\n",
      "215 6.586775779724121\n",
      "216 6.316912651062012\n",
      "217 6.058456897735596\n",
      "218 5.810366630554199\n",
      "219 5.572868824005127\n",
      "220 5.345065116882324\n",
      "221 5.1266560554504395\n",
      "222 4.917425155639648\n",
      "223 4.716660976409912\n",
      "224 4.524556636810303\n",
      "225 4.340378761291504\n",
      "226 4.163658618927002\n",
      "227 3.994236469268799\n",
      "228 3.8317785263061523\n",
      "229 3.6761226654052734\n",
      "230 3.5268383026123047\n",
      "231 3.383906126022339\n",
      "232 3.2465603351593018\n",
      "233 3.1150259971618652\n",
      "234 2.988828420639038\n",
      "235 2.8679425716400146\n",
      "236 2.7518584728240967\n",
      "237 2.640673875808716\n",
      "238 2.5340168476104736\n",
      "239 2.431691884994507\n",
      "240 2.3335652351379395\n",
      "241 2.2394731044769287\n",
      "242 2.1492316722869873\n",
      "243 2.0626134872436523\n",
      "244 1.979622721672058\n",
      "245 1.8999762535095215\n",
      "246 1.823598861694336\n",
      "247 1.7503103017807007\n",
      "248 1.6799509525299072\n",
      "249 1.6125677824020386\n",
      "250 1.5478930473327637\n",
      "251 1.4858715534210205\n",
      "252 1.4262574911117554\n",
      "253 1.3690857887268066\n",
      "254 1.3144354820251465\n",
      "255 1.2618528604507446\n",
      "256 1.2113248109817505\n",
      "257 1.1629149913787842\n",
      "258 1.116491675376892\n",
      "259 1.0719397068023682\n",
      "260 1.0291515588760376\n",
      "261 0.9881429672241211\n",
      "262 0.948744535446167\n",
      "263 0.9109649658203125\n",
      "264 0.8747091293334961\n",
      "265 0.8399428725242615\n",
      "266 0.8065575957298279\n",
      "267 0.7745365500450134\n",
      "268 0.743747353553772\n",
      "269 0.7142344117164612\n",
      "270 0.6858413219451904\n",
      "271 0.6586267352104187\n",
      "272 0.6324704885482788\n",
      "273 0.6074603796005249\n",
      "274 0.5833799242973328\n",
      "275 0.5603098273277283\n",
      "276 0.538133442401886\n",
      "277 0.5167998671531677\n",
      "278 0.4963699281215668\n",
      "279 0.47679442167282104\n",
      "280 0.4579639434814453\n",
      "281 0.4399140477180481\n",
      "282 0.4225271940231323\n",
      "283 0.4058454930782318\n",
      "284 0.3898758292198181\n",
      "285 0.37450703978538513\n",
      "286 0.3596993088722229\n",
      "287 0.34558868408203125\n",
      "288 0.3319928050041199\n",
      "289 0.3188936114311218\n",
      "290 0.3063645660877228\n",
      "291 0.2943674921989441\n",
      "292 0.282818078994751\n",
      "293 0.27166399359703064\n",
      "294 0.26100337505340576\n",
      "295 0.2508018910884857\n",
      "296 0.24095848202705383\n",
      "297 0.23150235414505005\n",
      "298 0.22243773937225342\n",
      "299 0.2137538194656372\n",
      "300 0.20537863671779633\n",
      "301 0.19736173748970032\n",
      "302 0.18962986767292023\n",
      "303 0.18225198984146118\n",
      "304 0.1751105785369873\n",
      "305 0.16825076937675476\n",
      "306 0.16170242428779602\n",
      "307 0.15539704263210297\n",
      "308 0.1493482142686844\n",
      "309 0.14353060722351074\n",
      "310 0.13794031739234924\n",
      "311 0.13258764147758484\n",
      "312 0.1274181455373764\n",
      "313 0.12247373908758163\n",
      "314 0.11770471930503845\n",
      "315 0.1131085455417633\n",
      "316 0.10872947424650192\n",
      "317 0.10449523478746414\n",
      "318 0.10042322427034378\n",
      "319 0.096539705991745\n",
      "320 0.09280091524124146\n",
      "321 0.08919163793325424\n",
      "322 0.0857444480061531\n",
      "323 0.08242055028676987\n",
      "324 0.0792209729552269\n",
      "325 0.07614120841026306\n",
      "326 0.07320591062307358\n",
      "327 0.07038237899541855\n",
      "328 0.06766688823699951\n",
      "329 0.06505295634269714\n",
      "330 0.06254184246063232\n",
      "331 0.06012542173266411\n",
      "332 0.05779905244708061\n",
      "333 0.05556876212358475\n",
      "334 0.05343194305896759\n",
      "335 0.05136372521519661\n",
      "336 0.04940014332532883\n",
      "337 0.047491997480392456\n",
      "338 0.04566337913274765\n",
      "339 0.043917763978242874\n",
      "340 0.04224282503128052\n",
      "341 0.04061662033200264\n",
      "342 0.039049748331308365\n",
      "343 0.037563201040029526\n",
      "344 0.036115214228630066\n",
      "345 0.03474092110991478\n",
      "346 0.03340528532862663\n",
      "347 0.0321282260119915\n",
      "348 0.03089768998324871\n",
      "349 0.02971574291586876\n",
      "350 0.028576789423823357\n",
      "351 0.02748893015086651\n",
      "352 0.026447011157870293\n",
      "353 0.025437500327825546\n",
      "354 0.024477161467075348\n",
      "355 0.02354368567466736\n",
      "356 0.02265477180480957\n",
      "357 0.021796781569719315\n",
      "358 0.020970573648810387\n",
      "359 0.02017371729016304\n",
      "360 0.019415026530623436\n",
      "361 0.01868613064289093\n",
      "362 0.017972717061638832\n",
      "363 0.017293401062488556\n",
      "364 0.01664922386407852\n",
      "365 0.01602707803249359\n",
      "366 0.015429402701556683\n",
      "367 0.014847123995423317\n",
      "368 0.01428967621177435\n",
      "369 0.013754935935139656\n",
      "370 0.013242769055068493\n",
      "371 0.012747152708470821\n",
      "372 0.012274167500436306\n",
      "373 0.011820298619568348\n",
      "374 0.01137969084084034\n",
      "375 0.01095337513834238\n",
      "376 0.010552221909165382\n",
      "377 0.010165980085730553\n",
      "378 0.009792990982532501\n",
      "379 0.009431964717805386\n",
      "380 0.009089745581150055\n",
      "381 0.0087546082213521\n",
      "382 0.008431270718574524\n",
      "383 0.008125036023557186\n",
      "384 0.007833220064640045\n",
      "385 0.007543003186583519\n",
      "386 0.007273838389664888\n",
      "387 0.007010471075773239\n",
      "388 0.006761929951608181\n",
      "389 0.006515162996947765\n",
      "390 0.0062854718416929245\n",
      "391 0.006053847726434469\n",
      "392 0.005839650984853506\n",
      "393 0.0056365784257650375\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "394 0.005430158693343401\n",
      "395 0.005243257619440556\n",
      "396 0.005058295093476772\n",
      "397 0.0048800683580338955\n",
      "398 0.004707938991487026\n",
      "399 0.004541801754385233\n",
      "400 0.004385354463011026\n",
      "401 0.0042332010343670845\n",
      "402 0.0040851193480193615\n",
      "403 0.003942274488508701\n",
      "404 0.003809330752119422\n",
      "405 0.0036788880825042725\n",
      "406 0.0035530496388673782\n",
      "407 0.0034328829497098923\n",
      "408 0.003316469956189394\n",
      "409 0.0032058244105428457\n",
      "410 0.003095718566328287\n",
      "411 0.002996482653543353\n",
      "412 0.002896404592320323\n",
      "413 0.002801347989588976\n",
      "414 0.0027062646113336086\n",
      "415 0.0026161009445786476\n",
      "416 0.002530781552195549\n",
      "417 0.002449025632813573\n",
      "418 0.002370838774368167\n",
      "419 0.002294242149218917\n",
      "420 0.002220114693045616\n",
      "421 0.002151642693206668\n",
      "422 0.0020829373970627785\n",
      "423 0.0020190104842185974\n",
      "424 0.0019563380628824234\n",
      "425 0.0018947365460917354\n",
      "426 0.0018343634437769651\n",
      "427 0.0017779992194846272\n",
      "428 0.0017241643508896232\n",
      "429 0.001670036930590868\n",
      "430 0.0016198739176616073\n",
      "431 0.0015696510672569275\n",
      "432 0.0015243508387356997\n",
      "433 0.0014775173040106893\n",
      "434 0.0014329483965411782\n",
      "435 0.0013928760308772326\n",
      "436 0.0013534030877053738\n",
      "437 0.0013135093031451106\n",
      "438 0.0012740943348035216\n",
      "439 0.0012363230343908072\n",
      "440 0.001201886567287147\n",
      "441 0.001167844282463193\n",
      "442 0.001134263351559639\n",
      "443 0.0011019724188372493\n",
      "444 0.0010723149171099067\n",
      "445 0.0010419739410281181\n",
      "446 0.0010126412380486727\n",
      "447 0.0009848815388977528\n",
      "448 0.0009563455241732299\n",
      "449 0.0009308728040196002\n",
      "450 0.00090641004499048\n",
      "451 0.0008831481100060046\n",
      "452 0.0008599660359323025\n",
      "453 0.0008373564924113452\n",
      "454 0.0008155326941050589\n",
      "455 0.000794055056758225\n",
      "456 0.0007728718337602913\n",
      "457 0.0007523726671934128\n",
      "458 0.0007338629802688956\n",
      "459 0.0007152488688006997\n",
      "460 0.000696428760420531\n",
      "461 0.0006795382359996438\n",
      "462 0.0006630202988162637\n",
      "463 0.0006455725524574518\n",
      "464 0.0006298987427726388\n",
      "465 0.0006149400724098086\n",
      "466 0.0006001737783662975\n",
      "467 0.0005853470065630972\n",
      "468 0.0005718616303056479\n",
      "469 0.0005581608274951577\n",
      "470 0.0005458852974697948\n",
      "471 0.0005324622616171837\n",
      "472 0.0005193008692003787\n",
      "473 0.0005074077052995563\n",
      "474 0.0004961146041750908\n",
      "475 0.0004842482740059495\n",
      "476 0.00047333139809779823\n",
      "477 0.00046403566375374794\n",
      "478 0.00045250554103404284\n",
      "479 0.000442950171418488\n",
      "480 0.0004327989590819925\n",
      "481 0.00042354772449471056\n",
      "482 0.000413733534514904\n",
      "483 0.00040549712139181793\n",
      "484 0.000397101161070168\n",
      "485 0.00038828636752441525\n",
      "486 0.0003802196588367224\n",
      "487 0.0003723677946254611\n",
      "488 0.00036461124545894563\n",
      "489 0.0003568623214960098\n",
      "490 0.0003494620032142848\n",
      "491 0.00034172856248915195\n",
      "492 0.00033552979584783316\n",
      "493 0.0003281632671132684\n",
      "494 0.00032163254218176007\n",
      "495 0.00031514454167336226\n",
      "496 0.00030899044941179454\n",
      "497 0.0003025623445864767\n",
      "498 0.0002974127419292927\n",
      "499 0.00029148231260478497\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "dtype = torch.float\n",
    "device = torch.device(\"cpu\")\n",
    "# device = torch.device(\"cuda:0\") # Uncomment this to run on GPU\n",
    "\n",
    "# N is the batch size; D_in is the input dimension;\n",
    "# H is the hidden dimension; D_out is the output dimension.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
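    "# Create random Tensors to hold the inputs and outputs.\n",
    "# requires_grad defaults to False, meaning we do not need gradients with respect to these Tensors during the backward pass.\n",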
    "x = torch.randn(N, D_in, device=device, dtype=dtype)\n",
    "y = torch.randn(N, D_out, device=device, dtype=dtype)\n",
    "\n",
    "# Create random Tensors for the weights.\n",
    "# requires_grad=True means we want autograd to compute gradients with respect to these Tensors during the backward pass.\n",
    "w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)\n",
    "w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "for t in range(500):\n",
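    "    # Forward pass: compute the predicted y with Tensor operations. This is the same forward pass as before,\n",
    "    # but we no longer need to keep the intermediate values ourselves, since we do not implement the backward pass by hand.\n",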
    "    y_pred = x.mm(w1).clamp(min=0).mm(w2)\n",
    "\n",
    "    # Compute the loss from the forward pass.\n",
    "    # loss is a 0-dimensional (scalar) Tensor;\n",
    "    # loss.item() returns its value as a Python number.\n",
    "    loss = (y_pred - y).pow(2).sum()\n",
    "    print(t, loss.item())\n",
    "\n",
    "    # Use autograd to compute the backward pass. loss.backward() computes the gradient of loss\n",
    "    # with respect to every Tensor that has requires_grad=True. After this call,\n",
    "    # w1.grad and w2.grad hold the gradients of loss with respect to w1 and w2.\n",
    "    loss.backward()\n",
    "\n",
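    "    # Update the weights manually with gradient descent (a more automatic approach is introduced later).\n",
    "    # Wrap the updates in torch.no_grad(): w1 and w2 have requires_grad=True,\n",
    "    # but the weight updates themselves should not be tracked by autograd.\n",
    "    # An alternative is to operate on weight.data and weight.grad.data, which also leaves the graph untouched:\n",
    "    # tensor.data returns a Tensor that shares memory with the original Tensor\n",
    "    # but does not record computation-graph history.\n",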
    "    with torch.no_grad():\n",
    "        w1 -= learning_rate * w1.grad\n",
    "        w2 -= learning_rate * w2.grad\n",
    "\n",
    "        # Manually zero the gradients after updating the weights;\n",
    "        # backward() accumulates into .grad instead of overwriting it.\n",
    "        w1.grad.zero_()\n",
    "        w2.grad.zero_()"
   ]
  },
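  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments in the cell above mention two ways to update the weights without autograd recording the update: inside `torch.no_grad()`, or through `.data`. The sketch below checks that both give the same result and that `.data` shares memory with the original tensor. The toy tensor sizes and learning rate are illustrative assumptions, not part of the original example. In current PyTorch, `torch.no_grad()` is generally preferred, because in-place changes made through `.data` are invisible to autograd, so it cannot warn you if they corrupt a later gradient computation.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "lr = 0.1  # illustrative learning rate\n",
    "\n",
    "# Two leaf tensors with identical values, one per update style (toy sizes).\n",
    "w_a = torch.randn(3, 2, requires_grad=True)\n",
    "w_b = w_a.detach().clone().requires_grad_(True)\n",
    "x = torch.randn(4, 3)\n",
    "\n",
    "# Same loss for both, so both receive identical gradients.\n",
    "for w in (w_a, w_b):\n",
    "    x.mm(w).pow(2).sum().backward()\n",
    "\n",
    "# Style 1: in-place update inside torch.no_grad(); autograd does not record it.\n",
    "with torch.no_grad():\n",
    "    w_a -= lr * w_a.grad\n",
    "\n",
    "# Style 2: in-place update through .data, which shares storage with w_b\n",
    "# but is not tracked by autograd.\n",
    "w_b.data -= lr * w_b.grad.data\n",
    "\n",
    "print(torch.allclose(w_a, w_b))                # True\n",
    "print(w_b.data.data_ptr() == w_b.data_ptr())   # True: same memory\n",
    "print(w_a.is_leaf, w_a.requires_grad)         # True True\n",
    "```"
   ]
  },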
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "PyTorch: nn\n",
    "-----------\n",
    "\n",
    "\n",
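    "This time we build the network with PyTorch's nn package.\n",
    "Autograd still builds the computation graph and computes the gradients automatically;\n",
    "nn adds higher-level building blocks (Modules such as layers and loss functions) on top of it.\n",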
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 616.8349609375\n",
      "1 570.4186401367188\n",
      "2 530.6421508789062\n",
      "3 495.7164001464844\n",
      "4 464.91497802734375\n",
      "5 437.1092834472656\n",
      "6 411.87066650390625\n",
      "7 388.5781555175781\n",
      "8 367.105224609375\n",
      "9 347.06768798828125\n",
      "10 328.3486328125\n",
      "11 310.6429748535156\n",
      "12 294.08880615234375\n",
      "13 278.54046630859375\n",
      "14 263.8558044433594\n",
      "15 249.83802795410156\n",
      "16 236.52313232421875\n",
      "17 223.8170166015625\n",
      "18 211.7015380859375\n",
      "19 200.13755798339844\n",
      "20 189.1465301513672\n",
      "21 178.6802520751953\n",
      "22 168.74122619628906\n",
      "23 159.31674194335938\n",
      "24 150.35125732421875\n",
      "25 141.79025268554688\n",
      "26 133.63401794433594\n",
      "27 125.89380645751953\n",
      "28 118.53340148925781\n",
      "29 111.54275512695312\n",
      "30 104.91582489013672\n",
      "31 98.65790557861328\n",
      "32 92.7421646118164\n",
      "33 87.18020629882812\n",
      "34 81.94192504882812\n",
      "35 77.01036834716797\n",
      "36 72.3639144897461\n",
      "37 67.99095916748047\n",
      "38 63.88977813720703\n",
      "39 60.036468505859375\n",
      "40 56.426231384277344\n",
      "41 53.05012512207031\n",
      "42 49.88925552368164\n",
      "43 46.92338943481445\n",
      "44 44.14652633666992\n",
      "45 41.54481887817383\n",
      "46 39.10710144042969\n",
      "47 36.83108139038086\n",
      "48 34.69960403442383\n",
      "49 32.706153869628906\n",
      "50 30.837730407714844\n",
      "51 29.088411331176758\n",
      "52 27.445863723754883\n",
      "53 25.909109115600586\n",
      "54 24.46878433227539\n",
      "55 23.117572784423828\n",
      "56 21.849687576293945\n",
      "57 20.658483505249023\n",
      "58 19.538114547729492\n",
      "59 18.48484992980957\n",
      "60 17.494647979736328\n",
      "61 16.563005447387695\n",
      "62 15.686758995056152\n",
      "63 14.856837272644043\n",
      "64 14.07450008392334\n",
      "65 13.338075637817383\n",
      "66 12.644659042358398\n",
      "67 11.991114616394043\n",
      "68 11.375090599060059\n",
      "69 10.79391098022461\n",
      "70 10.2453031539917\n",
      "71 9.728269577026367\n",
      "72 9.24122142791748\n",
      "73 8.78144359588623\n",
      "74 8.347025871276855\n",
      "75 7.936522960662842\n",
      "76 7.547698497772217\n",
      "77 7.180155277252197\n",
      "78 6.832812786102295\n",
      "79 6.503701210021973\n",
      "80 6.192763328552246\n",
      "81 5.897329330444336\n",
      "82 5.617427349090576\n",
      "83 5.3520426750183105\n",
      "84 5.100094795227051\n",
      "85 4.860849380493164\n",
      "86 4.634103298187256\n",
      "87 4.419088363647461\n",
      "88 4.2145843505859375\n",
      "89 4.020376682281494\n",
      "90 3.8357491493225098\n",
      "91 3.660381317138672\n",
      "92 3.493623733520508\n",
      "93 3.3348333835601807\n",
      "94 3.1840157508850098\n",
      "95 3.040458917617798\n",
      "96 2.90377140045166\n",
      "97 2.7736332416534424\n",
      "98 2.6498191356658936\n",
      "99 2.531902313232422\n",
      "100 2.419586658477783\n",
      "101 2.312699317932129\n",
      "102 2.2109007835388184\n",
      "103 2.113172769546509\n",
      "104 2.0200531482696533\n",
      "105 1.931264042854309\n",
      "106 1.8465205430984497\n",
      "107 1.765866994857788\n",
      "108 1.6889408826828003\n",
      "109 1.6155205965042114\n",
      "110 1.545505404472351\n",
      "111 1.478671669960022\n",
      "112 1.4148930311203003\n",
      "113 1.354078769683838\n",
      "114 1.2960002422332764\n",
      "115 1.2405637502670288\n",
      "116 1.1877013444900513\n",
      "117 1.1373679637908936\n",
      "118 1.0893242359161377\n",
      "119 1.0433975458145142\n",
      "120 0.9995244741439819\n",
      "121 0.957534909248352\n",
      "122 0.9174259901046753\n",
      "123 0.8791078329086304\n",
      "124 0.8424510359764099\n",
      "125 0.8073628544807434\n",
      "126 0.7738264203071594\n",
      "127 0.7417508959770203\n",
      "128 0.7110787630081177\n",
      "129 0.6817481517791748\n",
      "130 0.6536678075790405\n",
      "131 0.6267886161804199\n",
      "132 0.6010832786560059\n",
      "133 0.5764719843864441\n",
      "134 0.55291348695755\n",
      "135 0.5303621292114258\n",
      "136 0.508778989315033\n",
      "137 0.4881012439727783\n",
      "138 0.46830493211746216\n",
      "139 0.4493491053581238\n",
      "140 0.4311933219432831\n",
      "141 0.41380244493484497\n",
      "142 0.3971446752548218\n",
      "143 0.3811851739883423\n",
      "144 0.3658958375453949\n",
      "145 0.35125595331192017\n",
      "146 0.33723005652427673\n",
      "147 0.3237687349319458\n",
      "148 0.3108976483345032\n",
      "149 0.2985706031322479\n",
      "150 0.28673914074897766\n",
      "151 0.27540862560272217\n",
      "152 0.26455509662628174\n",
      "153 0.2541574537754059\n",
      "154 0.24418404698371887\n",
      "155 0.23461730778217316\n",
      "156 0.2254522293806076\n",
      "157 0.2166563868522644\n",
      "158 0.2082248032093048\n",
      "159 0.20014120638370514\n",
      "160 0.1923948973417282\n",
      "161 0.18496574461460114\n",
      "162 0.17784026265144348\n",
      "163 0.17100077867507935\n",
      "164 0.16445767879486084\n",
      "165 0.15817493200302124\n",
      "166 0.1521458625793457\n",
      "167 0.14637260138988495\n",
      "168 0.14082613587379456\n",
      "169 0.13550066947937012\n",
      "170 0.1303861439228058\n",
      "171 0.12548044323921204\n",
      "172 0.12077119201421738\n",
      "173 0.11624855548143387\n",
      "174 0.11190006136894226\n",
      "175 0.10772517323493958\n",
      "176 0.10371194034814835\n",
      "177 0.0998566597700119\n",
      "178 0.09615585952997208\n",
      "179 0.09260128438472748\n",
      "180 0.08918008953332901\n",
      "181 0.08589056134223938\n",
      "182 0.08273126929998398\n",
      "183 0.07969754189252853\n",
      "184 0.076775923371315\n",
      "185 0.07396624237298965\n",
      "186 0.07126504927873611\n",
      "187 0.0686689168214798\n",
      "188 0.06617266684770584\n",
      "189 0.06377072632312775\n",
      "190 0.06145996227860451\n",
      "191 0.05924048647284508\n",
      "192 0.057103026658296585\n",
      "193 0.055055245757102966\n",
      "194 0.05308816209435463\n",
      "195 0.05119483917951584\n",
      "196 0.04937274008989334\n",
      "197 0.04762033745646477\n",
      "198 0.04593187943100929\n",
      "199 0.044306349009275436\n",
      "200 0.04274118319153786\n",
      "201 0.04123516380786896\n",
      "202 0.03978639841079712\n",
      "203 0.038389310240745544\n",
      "204 0.03704371303319931\n",
      "205 0.0357489138841629\n",
      "206 0.03450245037674904\n",
      "207 0.03330256789922714\n",
      "208 0.0321439690887928\n",
      "209 0.03102947771549225\n",
      "210 0.02995586395263672\n",
      "211 0.02892051823437214\n",
      "212 0.02792329527437687\n",
      "213 0.026961468160152435\n",
      "214 0.02603522501885891\n",
      "215 0.02514214999973774\n",
      "216 0.02428179606795311\n",
      "217 0.02345244586467743\n",
      "218 0.02265304885804653\n",
      "219 0.021882878616452217\n",
      "220 0.02113926038146019\n",
      "221 0.02042245678603649\n",
      "222 0.019731855019927025\n",
      "223 0.01906573213636875\n",
      "224 0.018423302099108696\n",
      "225 0.017803329974412918\n",
      "226 0.017205584794282913\n",
      "227 0.016628772020339966\n",
      "228 0.016072314232587814\n",
      "229 0.015535356476902962\n",
      "230 0.015017733909189701\n",
      "231 0.014518306590616703\n",
      "232 0.014036747626960278\n",
      "233 0.013571222312748432\n",
      "234 0.013121938332915306\n",
      "235 0.012688718736171722\n",
      "236 0.012270272709429264\n",
      "237 0.011866560205817223\n",
      "238 0.011476823128759861\n",
      "239 0.011100317351520061\n",
      "240 0.010736910626292229\n",
      "241 0.010386170819401741\n",
      "242 0.0100497892126441\n",
      "243 0.0097251171246171\n",
      "244 0.009411685168743134\n",
      "245 0.009108913131058216\n",
      "246 0.008816596120595932\n",
      "247 0.00853422749787569\n",
      "248 0.008261235430836678\n",
      "249 0.007997557520866394\n",
      "250 0.007742919027805328\n",
      "251 0.007496801670640707\n",
      "252 0.007259095553308725\n",
      "253 0.007029606495052576\n",
      "254 0.006807629019021988\n",
      "255 0.006593053694814444\n",
      "256 0.006385627202689648\n",
      "257 0.006185163743793964\n",
      "258 0.005991355516016483\n",
      "259 0.005803895648568869\n",
      "260 0.005622652359306812\n",
      "261 0.0054474519565701485\n",
      "262 0.005278183612972498\n",
      "263 0.005114411935210228\n",
      "264 0.004955946933478117\n",
      "265 0.004802867770195007\n",
      "266 0.004654645454138517\n",
      "267 0.004511530976742506\n",
      "268 0.004373411647975445\n",
      "269 0.00423995777964592\n",
      "270 0.004110764712095261\n",
      "271 0.003985690884292126\n",
      "272 0.003864638740196824\n",
      "273 0.0037474501878023148\n",
      "274 0.0036340728402137756\n",
      "275 0.0035244303289800882\n",
      "276 0.0034184216056019068\n",
      "277 0.0033156615681946278\n",
      "278 0.0032162207644432783\n",
      "279 0.003119939938187599\n",
      "280 0.0030267902184277773\n",
      "281 0.0029364691581577063\n",
      "282 0.002849065698683262\n",
      "283 0.00276445085182786\n",
      "284 0.0026824434753507376\n",
      "285 0.0026030712760984898\n",
      "286 0.002526269294321537\n",
      "287 0.002451810520142317\n",
      "288 0.0023796753957867622\n",
      "289 0.0023097810335457325\n",
      "290 0.002242131158709526\n",
      "291 0.002176533453166485\n",
      "292 0.002112946705892682\n",
      "293 0.002051342511549592\n",
      "294 0.0019916510209441185\n",
      "295 0.0019338340498507023\n",
      "296 0.0018778254743665457\n",
      "297 0.0018235259922221303\n",
      "298 0.0017708562081679702\n",
      "299 0.001719821011647582\n",
      "300 0.001670365920290351\n",
      "301 0.0016224351711571217\n",
      "302 0.001575930742546916\n",
      "303 0.0015308443689718843\n",
      "304 0.0014871100429445505\n",
      "305 0.0014447338180616498\n",
      "306 0.0014036260545253754\n",
      "307 0.0013637744123116136\n",
      "308 0.0013251130003482103\n",
      "309 0.00128761469386518\n",
      "310 0.00125124619808048\n",
      "311 0.0012159458128735423\n",
      "312 0.0011817403137683868\n",
      "313 0.0011485485592857003\n",
      "314 0.0011163217714056373\n",
      "315 0.0010850488906726241\n",
      "316 0.0010547296842560172\n",
      "317 0.0010253143263980746\n",
      "318 0.0009967893129214644\n",
      "319 0.0009690661099739373\n",
      "320 0.0009421803988516331\n",
      "321 0.0009160566260106862\n",
      "322 0.0008907159208320081\n",
      "323 0.0008661439060233533\n",
      "324 0.0008422775426879525\n",
      "325 0.0008190966327674687\n",
      "326 0.0007965879631228745\n",
      "327 0.0007747435010969639\n",
      "328 0.000753526808694005\n",
      "329 0.0007329185027629137\n",
      "330 0.000712935405317694\n",
      "331 0.0006935214041732252\n",
      "332 0.0006746658473275602\n",
      "333 0.0006563455681316555\n",
      "334 0.0006385522428900003\n",
      "335 0.000621272309217602\n",
      "336 0.0006044972687959671\n",
      "337 0.0005881927208974957\n",
      "338 0.00057236134307459\n",
      "339 0.0005569689092226326\n",
      "340 0.0005420160596258938\n",
      "341 0.000527493713889271\n",
      "342 0.000513377774041146\n",
      "343 0.0004996666684746742\n",
      "344 0.0004863426147494465\n",
      "345 0.00047339097363874316\n",
      "346 0.0004608244926203042\n",
      "347 0.0004486078687477857\n",
      "348 0.0004367205547168851\n",
      "349 0.0004251683712936938\n",
      "350 0.0004139319353271276\n",
      "351 0.0004030133131891489\n",
      "352 0.00039240249316208065\n",
      "353 0.00038208605838008225\n",
      "354 0.00037205807166174054\n",
      "355 0.0003623157390393317\n",
      "356 0.0003528340021148324\n",
      "357 0.00034362293081358075\n",
      "358 0.0003346629673615098\n",
      "359 0.0003259552177041769\n",
      "360 0.0003174722078256309\n",
      "361 0.00030922345467843115\n",
      "362 0.0003012036031577736\n",
      "363 0.00029340494074858725\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "364 0.00028583104722201824\n",
      "365 0.000278460793197155\n",
      "366 0.00027128090732730925\n",
      "367 0.00026430474827066064\n",
      "368 0.0002575131948105991\n",
      "369 0.0002509095938876271\n",
      "370 0.00024448230396956205\n",
      "371 0.00023822381626814604\n",
      "372 0.0002321432693861425\n",
      "373 0.00022622810502070934\n",
      "374 0.0002204657648690045\n",
      "375 0.0002148632047465071\n",
      "376 0.00020940121612511575\n",
      "377 0.00020409838180057704\n",
      "378 0.0001989272132050246\n",
      "379 0.00019389843509998173\n",
      "380 0.00018900231225416064\n",
      "381 0.0001842392230173573\n",
      "382 0.00017960301192943007\n",
      "383 0.00017508988094050437\n",
      "384 0.00017069902969524264\n",
      "385 0.00016641429101582617\n",
      "386 0.00016225305444095284\n",
      "387 0.000158198265125975\n",
      "388 0.00015424926823470742\n",
      "389 0.000150404914165847\n",
      "390 0.0001466635148972273\n",
      "391 0.00014301779447123408\n",
      "392 0.00013946628314442933\n",
      "393 0.00013601550017483532\n",
      "394 0.00013265143206808716\n",
      "395 0.00012937198334839195\n",
      "396 0.00012617645552381873\n",
      "397 0.0001230676716659218\n",
      "398 0.0001200390252051875\n",
      "399 0.00011708753299899399\n",
      "400 0.00011421682575019076\n",
      "401 0.00011141804861836135\n",
      "402 0.00010869121615542099\n",
      "403 0.00010603333066683263\n",
      "404 0.00010344652400817722\n",
      "405 0.00010092253796756268\n",
      "406 9.846295870374888e-05\n",
      "407 9.6071045845747e-05\n",
      "408 9.373520879307762e-05\n",
      "409 9.146131196757779e-05\n",
      "410 8.924482244765386e-05\n",
      "411 8.70852600201033e-05\n",
      "412 8.498283568769693e-05\n",
      "413 8.293260907521471e-05\n",
      "414 8.093572250800207e-05\n",
      "415 7.89878613431938e-05\n",
      "416 7.708741031819955e-05\n",
      "417 7.523671229137108e-05\n",
      "418 7.34319764887914e-05\n",
      "419 7.167273724917322e-05\n",
      "420 6.995951844146475e-05\n",
      "421 6.828828190919012e-05\n",
      "422 6.665829278063029e-05\n",
      "423 6.507106445496902e-05\n",
      "424 6.35203905403614e-05\n",
      "425 6.200941425049677e-05\n",
      "426 6.05389905103948e-05\n",
      "427 5.9100995713379234e-05\n",
      "428 5.7702251069713384e-05\n",
      "429 5.633640830637887e-05\n",
      "430 5.500561746885069e-05\n",
      "431 5.3707120969193056e-05\n",
      "432 5.2441897423705086e-05\n",
      "433 5.1206407079007477e-05\n",
      "434 5.000238525099121e-05\n",
      "435 4.882620123680681e-05\n",
      "436 4.76813547720667e-05\n",
      "437 4.656398596125655e-05\n",
      "438 4.5474393118638545e-05\n",
      "439 4.4412216084310785e-05\n",
      "440 4.3375530367484316e-05\n",
      "441 4.236115637468174e-05\n",
      "442 4.137431824347004e-05\n",
      "443 4.0410945075564086e-05\n",
      "444 3.94726412196178e-05\n",
      "445 3.8555674109375104e-05\n",
      "446 3.765980727621354e-05\n",
      "447 3.6788511351915076e-05\n",
      "448 3.593677683966234e-05\n",
      "449 3.510593160171993e-05\n",
      "450 3.4294465876882896e-05\n",
      "451 3.350316546857357e-05\n",
      "452 3.273158290539868e-05\n",
      "453 3.197851401637308e-05\n",
      "454 3.124314753222279e-05\n",
      "455 3.0524690373567864e-05\n",
      "456 2.982511796290055e-05\n",
      "457 2.9140464903321117e-05\n",
      "458 2.8473716156440787e-05\n",
      "459 2.7822752599604428e-05\n",
      "460 2.718799078138545e-05\n",
      "461 2.656610740814358e-05\n",
      "462 2.5960111088352278e-05\n",
      "463 2.536940883146599e-05\n",
      "464 2.479194699844811e-05\n",
      "465 2.4228789698099717e-05\n",
      "466 2.367901470279321e-05\n",
      "467 2.314150333404541e-05\n",
      "468 2.2617272406932898e-05\n",
      "469 2.2104968593339436e-05\n",
      "470 2.1604977519018576e-05\n",
      "471 2.111716639774386e-05\n",
      "472 2.0640816728700884e-05\n",
      "473 2.0175557438051328e-05\n",
      "474 1.9721246644621715e-05\n",
      "475 1.92775023606373e-05\n",
      "476 1.8843960788217373e-05\n",
      "477 1.8419739717501216e-05\n",
      "478 1.8007833205047064e-05\n",
      "479 1.760409759299364e-05\n",
      "480 1.7209487850777805e-05\n",
      "481 1.6824411432025954e-05\n",
      "482 1.6448399037471972e-05\n",
      "483 1.608097954886034e-05\n",
      "484 1.5722764146630652e-05\n",
      "485 1.537234493298456e-05\n",
      "486 1.5029500900709536e-05\n",
      "487 1.4695562640554272e-05\n",
      "488 1.4368548363563605e-05\n",
      "489 1.4049977835384198e-05\n",
      "490 1.3737889275944326e-05\n",
      "491 1.3433097592496779e-05\n",
      "492 1.3135630069882609e-05\n",
      "493 1.2845604032918345e-05\n",
      "494 1.256068117072573e-05\n",
      "495 1.2283581781957764e-05\n",
      "496 1.2012886145384982e-05\n",
      "497 1.1748711585823912e-05\n",
      "498 1.1490279575809836e-05\n",
      "499 1.1236619684495963e-05\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# N is batch size; D_in is input dimension;\n",
    "# H is hidden dimension; D_out is output dimension.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# Create random Tensors to hold inputs and outputs\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "# Use the nn package to define our model as a sequence of layers. nn.Sequential\n",
    "# is a Module which contains other Modules, and applies them in sequence to\n",
    "# produce its output. Each Linear Module computes output from input using a\n",
    "# linear function, and holds internal Tensors for its weight and bias.\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(D_in, H),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(H, D_out),\n",
    ")\n",
    "\n",
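    "# The nn package also contains definitions of popular loss functions; in this\n",
    "# case we use MSELoss. With reduction='sum' it returns the sum (not the mean) of the\n",
    "# squared errors, matching the (y_pred - y).pow(2).sum() loss used earlier.\n",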
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "learning_rate = 1e-4\n",
    "for t in range(500):\n",
    "    # Forward pass: compute predicted y by passing x to the model. Module objects\n",
    "    # override the __call__ operator so you can call them like functions. When\n",
    "    # doing so you pass a Tensor of input data to the Module and it produces\n",
    "    # a Tensor of output data.\n",
    "    y_pred = model(x)\n",
    "\n",
    "    # Compute and print loss. We pass Tensors containing the predicted and true\n",
    "    # values of y, and the loss function returns a Tensor containing the\n",
    "    # loss.\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "\n",
    "    # Zero the gradients before running the backward pass.\n",
    "    model.zero_grad()\n",
    "\n",
    "    # Backward pass: compute gradient of the loss with respect to all the learnable\n",
    "    # parameters of the model. Internally, the parameters of each Module are stored\n",
    "    # in Tensors with requires_grad=True, so this call will compute gradients for\n",
    "    # all learnable parameters in the model.\n",
    "    loss.backward()\n",
    "\n",
    "    # Update the weights using gradient descent. Each parameter is a Tensor, so\n",
    "    # we can access its gradients like we did before.\n",
    "    with torch.no_grad():\n",
    "        for param in model.parameters():\n",
    "            param -= learning_rate * param.grad"
   ]
  },
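  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the nn version with the earlier autograd version, the sketch below checks three things on a tiny model: `nn.Sequential` simply applies its sub-modules in order, each `nn.Linear` computes `x @ weight.T + bias` with `weight` of shape `(out_features, in_features)`, and `MSELoss(reduction='sum')` equals the hand-written `(y_pred - y).pow(2).sum()`. The small sizes and the seed are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "N, D_in, H, D_out = 4, 5, 3, 2  # small sizes chosen only for illustration\n",
    "\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(D_in, H),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(H, D_out),\n",
    ")\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "y_pred = model(x)\n",
    "\n",
    "# nn.Sequential simply applies its sub-modules in order.\n",
    "lin1, relu, lin2 = model[0], model[1], model[2]\n",
    "by_modules = lin2(relu(lin1(x)))\n",
    "\n",
    "# Each Linear computes x @ weight.T + bias, with weight of shape (out_features, in_features).\n",
    "by_formula = (x.mm(lin1.weight.t()) + lin1.bias).clamp(min=0).mm(lin2.weight.t()) + lin2.bias\n",
    "\n",
    "# MSELoss with reduction='sum' is the sum of squared errors used in the autograd version.\n",
    "loss_nn = torch.nn.MSELoss(reduction='sum')(y_pred, y)\n",
    "loss_manual = (y_pred - y).pow(2).sum()\n",
    "\n",
    "print(torch.allclose(y_pred, by_modules, atol=1e-6))   # True\n",
    "print(torch.allclose(y_pred, by_formula, atol=1e-6))   # True\n",
    "print(torch.allclose(loss_nn, loss_manual, atol=1e-6)) # True\n",
    "print([tuple(p.shape) for p in model.parameters()])    # [(3, 5), (3,), (2, 3), (2,)]\n",
    "```\n",
    "\n",
    "Because all four parameter tensors have `requires_grad=True`, `loss.backward()` fills in `.grad` for each of them, which is what the update loop in the nn cell relies on."
   ]
  },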
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "PyTorch: optim\n",
    "--------------\n",
    "\n",
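    "This time, instead of updating the model's weights by hand, we let the optim package update the parameters for us.\n",
    "optim provides many commonly used optimization algorithms, including SGD with momentum, RMSProp, and Adam.\n"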
   ]
  },
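  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For plain SGD without momentum, one `optimizer.step()` should perform the same update as the manual loop `param -= learning_rate * param.grad` from the nn example. The sketch below checks this on a tiny model; the layer sizes, the learning rate, and the use of `copy.deepcopy` to give both models identical starting weights are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "learning_rate = 1e-2  # illustrative value\n",
    "\n",
    "model_a = torch.nn.Sequential(torch.nn.Linear(5, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))\n",
    "model_b = copy.deepcopy(model_a)  # identical starting weights\n",
    "x, y = torch.randn(4, 5), torch.randn(4, 2)\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "# Manual update, as in the nn example.\n",
    "model_a.zero_grad()\n",
    "loss_fn(model_a(x), y).backward()\n",
    "with torch.no_grad():\n",
    "    for param in model_a.parameters():\n",
    "        param -= learning_rate * param.grad\n",
    "\n",
    "# The same step through torch.optim: zero_grad(), backward(), step().\n",
    "optimizer = torch.optim.SGD(model_b.parameters(), lr=learning_rate)\n",
    "optimizer.zero_grad()\n",
    "loss_fn(model_b(x), y).backward()\n",
    "optimizer.step()\n",
    "\n",
    "same = all(torch.allclose(p, q, atol=1e-6)\n",
    "           for p, q in zip(model_a.parameters(), model_b.parameters()))\n",
    "print(same)  # True\n",
    "```\n",
    "\n",
    "Swapping in another algorithm, such as `torch.optim.Adam(model_b.parameters(), lr=learning_rate)`, changes the update rule but keeps the same `zero_grad()` / `backward()` / `step()` loop."
   ]
  },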
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 791.6784057617188\n",
      "1 772.9479370117188\n",
      "2 754.7570190429688\n",
      "3 737.1196899414062\n",
      "4 719.9917602539062\n",
      "5 703.4118041992188\n",
      "6 687.2720947265625\n",
      "7 671.5335083007812\n",
      "8 656.197021484375\n",
      "9 641.322265625\n",
      "10 626.919189453125\n",
      "11 612.9500732421875\n",
      "12 599.3975219726562\n",
      "13 586.327392578125\n",
      "14 573.5608520507812\n",
      "15 561.1778564453125\n",
      "16 549.1342163085938\n",
      "17 537.3661499023438\n",
      "18 525.8930053710938\n",
      "19 514.7115478515625\n",
      "20 503.76434326171875\n",
      "21 493.110107421875\n",
      "22 482.747802734375\n",
      "23 472.5677490234375\n",
      "24 462.6549072265625\n",
      "25 452.9748840332031\n",
      "26 443.5308532714844\n",
      "27 434.2815856933594\n",
      "28 425.22064208984375\n",
      "29 416.3988342285156\n",
      "30 407.7718505859375\n",
      "31 399.32879638671875\n",
      "32 391.1062927246094\n",
      "33 383.07781982421875\n",
      "34 375.2408752441406\n",
      "35 367.54071044921875\n",
      "36 359.98126220703125\n",
      "37 352.5944519042969\n",
      "38 345.3690185546875\n",
      "39 338.3002624511719\n",
      "40 331.3758544921875\n",
      "41 324.56463623046875\n",
      "42 317.8782958984375\n",
      "43 311.31622314453125\n",
      "44 304.8724670410156\n",
      "45 298.5438537597656\n",
      "46 292.3463439941406\n",
      "47 286.2505187988281\n",
      "48 280.26092529296875\n",
      "49 274.3931884765625\n",
      "50 268.6314697265625\n",
      "51 262.9659423828125\n",
      "52 257.3974914550781\n",
      "53 251.95529174804688\n",
      "54 246.6234893798828\n",
      "55 241.38490295410156\n",
      "56 236.2388153076172\n",
      "57 231.19229125976562\n",
      "58 226.2274627685547\n",
      "59 221.34634399414062\n",
      "60 216.5567169189453\n",
      "61 211.8523406982422\n",
      "62 207.2224884033203\n",
      "63 202.67787170410156\n",
      "64 198.20230102539062\n",
      "65 193.80657958984375\n",
      "66 189.49249267578125\n",
      "67 185.26438903808594\n",
      "68 181.1260223388672\n",
      "69 177.0719451904297\n",
      "70 173.08523559570312\n",
      "71 169.17050170898438\n",
      "72 165.3196258544922\n",
      "73 161.5393524169922\n",
      "74 157.82603454589844\n",
      "75 154.1715545654297\n",
      "76 150.57814025878906\n",
      "77 147.04661560058594\n",
      "78 143.57847595214844\n",
      "79 140.1695098876953\n",
      "80 136.83058166503906\n",
      "81 133.5507049560547\n",
      "82 130.3309326171875\n",
      "83 127.16106414794922\n",
      "84 124.03736877441406\n",
      "85 120.97663116455078\n",
      "86 117.97151947021484\n",
      "87 115.01973724365234\n",
      "88 112.12226104736328\n",
      "89 109.28099822998047\n",
      "90 106.49259948730469\n",
      "91 103.7560806274414\n",
      "92 101.0788345336914\n",
      "93 98.46456146240234\n",
      "94 95.8936538696289\n",
      "95 93.37211608886719\n",
      "96 90.9074478149414\n",
      "97 88.49383544921875\n",
      "98 86.1294174194336\n",
      "99 83.81201171875\n",
      "100 81.54666137695312\n",
      "101 79.3274917602539\n",
      "102 77.15277862548828\n",
      "103 75.02595520019531\n",
      "104 72.94225311279297\n",
      "105 70.90507507324219\n",
      "106 68.91564178466797\n",
      "107 66.96821594238281\n",
      "108 65.06224822998047\n",
      "109 63.201744079589844\n",
      "110 61.382633209228516\n",
      "111 59.60671615600586\n",
      "112 57.873958587646484\n",
      "113 56.179588317871094\n",
      "114 54.523807525634766\n",
      "115 52.906803131103516\n",
      "116 51.32786178588867\n",
      "117 49.788665771484375\n",
      "118 48.285526275634766\n",
      "119 46.82194137573242\n",
      "120 45.393489837646484\n",
      "121 44.00169372558594\n",
      "122 42.64366149902344\n",
      "123 41.320953369140625\n",
      "124 40.03276824951172\n",
      "125 38.77764892578125\n",
      "126 37.554481506347656\n",
      "127 36.362998962402344\n",
      "128 35.20222854614258\n",
      "129 34.07368850708008\n",
      "130 32.9746208190918\n",
      "131 31.906068801879883\n",
      "132 30.866662979125977\n",
      "133 29.854427337646484\n",
      "134 28.871971130371094\n",
      "135 27.915695190429688\n",
      "136 26.98732566833496\n",
      "137 26.086950302124023\n",
      "138 25.21206283569336\n",
      "139 24.362747192382812\n",
      "140 23.53691291809082\n",
      "141 22.73594093322754\n",
      "142 21.958166122436523\n",
      "143 21.20453453063965\n",
      "144 20.472272872924805\n",
      "145 19.762775421142578\n",
      "146 19.074495315551758\n",
      "147 18.407861709594727\n",
      "148 17.762025833129883\n",
      "149 17.135404586791992\n",
      "150 16.52802085876465\n",
      "151 15.940832138061523\n",
      "152 15.3715238571167\n",
      "153 14.819578170776367\n",
      "154 14.285402297973633\n",
      "155 13.767573356628418\n",
      "156 13.266966819763184\n",
      "157 12.782539367675781\n",
      "158 12.314498901367188\n",
      "159 11.861427307128906\n",
      "160 11.422765731811523\n",
      "161 10.999663352966309\n",
      "162 10.590702056884766\n",
      "163 10.195518493652344\n",
      "164 9.813480377197266\n",
      "165 9.444453239440918\n",
      "166 9.088181495666504\n",
      "167 8.743974685668945\n",
      "168 8.411545753479004\n",
      "169 8.090726852416992\n",
      "170 7.781579494476318\n",
      "171 7.482759475708008\n",
      "172 7.194526672363281\n",
      "173 6.916567325592041\n",
      "174 6.648494243621826\n",
      "175 6.389762878417969\n",
      "176 6.139918327331543\n",
      "177 5.899596214294434\n",
      "178 5.667441368103027\n",
      "179 5.443691253662109\n",
      "180 5.22829008102417\n",
      "181 5.020461559295654\n",
      "182 4.8205108642578125\n",
      "183 4.627894401550293\n",
      "184 4.442294120788574\n",
      "185 4.263693332672119\n",
      "186 4.091548919677734\n",
      "187 3.9257419109344482\n",
      "188 3.766235828399658\n",
      "189 3.6124517917633057\n",
      "190 3.4642839431762695\n",
      "191 3.3219587802886963\n",
      "192 3.18510103225708\n",
      "193 3.0533838272094727\n",
      "194 2.926755905151367\n",
      "195 2.8052432537078857\n",
      "196 2.688368082046509\n",
      "197 2.576014757156372\n",
      "198 2.4679884910583496\n",
      "199 2.3642969131469727\n",
      "200 2.264831066131592\n",
      "201 2.16910457611084\n",
      "202 2.077399253845215\n",
      "203 1.9893807172775269\n",
      "204 1.9050089120864868\n",
      "205 1.8239641189575195\n",
      "206 1.746296763420105\n",
      "207 1.6716521978378296\n",
      "208 1.6001081466674805\n",
      "209 1.531592607498169\n",
      "210 1.4656728506088257\n",
      "211 1.4025815725326538\n",
      "212 1.34219491481781\n",
      "213 1.2842515707015991\n",
      "214 1.2286741733551025\n",
      "215 1.175434947013855\n",
      "216 1.1243504285812378\n",
      "217 1.0754128694534302\n",
      "218 1.028563141822815\n",
      "219 0.9836430549621582\n",
      "220 0.9406015276908875\n",
      "221 0.899389386177063\n",
      "222 0.8599286079406738\n",
      "223 0.8221437335014343\n",
      "224 0.7859696745872498\n",
      "225 0.7512990832328796\n",
      "226 0.7181407809257507\n",
      "227 0.686427891254425\n",
      "228 0.6560250520706177\n",
      "229 0.6269176602363586\n",
      "230 0.5990850329399109\n",
      "231 0.5724562406539917\n",
      "232 0.5469314455986023\n",
      "233 0.5225721001625061\n",
      "234 0.4992505609989166\n",
      "235 0.4769216477870941\n",
      "236 0.4555879533290863\n",
      "237 0.4351666569709778\n",
      "238 0.41567420959472656\n",
      "239 0.3969923257827759\n",
      "240 0.3791610300540924\n",
      "241 0.3620989918708801\n",
      "242 0.3457787036895752\n",
      "243 0.3301834166049957\n",
      "244 0.3152735233306885\n",
      "245 0.30102694034576416\n",
      "246 0.2874011993408203\n",
      "247 0.2743793725967407\n",
      "248 0.26194536685943604\n",
      "249 0.2500665485858917\n",
      "250 0.23870456218719482\n",
      "251 0.2278587371110916\n",
      "252 0.21748881042003632\n",
      "253 0.2075870931148529\n",
      "254 0.19812825322151184\n",
      "255 0.18910430371761322\n",
      "256 0.18047408759593964\n",
      "257 0.1722312867641449\n",
      "258 0.1643613874912262\n",
      "259 0.15685199201107025\n",
      "260 0.1496802419424057\n",
      "261 0.14283005893230438\n",
      "262 0.13628333806991577\n",
      "263 0.13004115223884583\n",
      "264 0.12407437711954117\n",
      "265 0.11838021129369736\n",
      "266 0.11294573545455933\n",
      "267 0.10776441544294357\n",
      "268 0.10281281918287277\n",
      "269 0.09808588773012161\n",
      "270 0.09357728809118271\n",
      "271 0.08927135169506073\n",
      "272 0.08516616374254227\n",
      "273 0.08124646544456482\n",
      "274 0.07750718295574188\n",
      "275 0.07393348217010498\n",
      "276 0.07052432745695114\n",
      "277 0.06727367639541626\n",
      "278 0.06417549401521683\n",
      "279 0.061211828142404556\n",
      "280 0.05838524177670479\n",
      "281 0.055691372603178024\n",
      "282 0.05311958119273186\n",
      "283 0.050665080547332764\n",
      "284 0.04832402989268303\n",
      "285 0.046091243624687195\n",
      "286 0.04395892843604088\n",
      "287 0.04192692041397095\n",
      "288 0.03998580947518349\n",
      "289 0.0381360799074173\n",
      "290 0.03637010604143143\n",
      "291 0.034684911370277405\n",
      "292 0.0330788791179657\n",
      "293 0.03154633194208145\n",
      "294 0.03008444234728813\n",
      "295 0.028689226135611534\n",
      "296 0.027359401807188988\n",
      "297 0.02609046921133995\n",
      "298 0.02487858757376671\n",
      "299 0.023724785074591637\n",
      "300 0.02262282185256481\n",
      "301 0.0215724129229784\n",
      "302 0.020570827648043633\n",
      "303 0.01961461454629898\n",
      "304 0.018704809248447418\n",
      "305 0.017835130915045738\n",
      "306 0.01700596511363983\n",
      "307 0.016215775161981583\n",
      "308 0.015460998751223087\n",
      "309 0.014741620048880577\n",
      "310 0.014055773615837097\n",
      "311 0.013400538824498653\n",
      "312 0.01277637854218483\n",
      "313 0.012180116027593613\n",
      "314 0.011613084003329277\n",
      "315 0.011071483604609966\n",
      "316 0.010555664077401161\n",
      "317 0.010063034482300282\n",
      "318 0.009593896567821503\n",
      "319 0.009146292693912983\n",
      "320 0.00871907826513052\n",
      "321 0.008312968537211418\n",
      "322 0.0079236114397645\n",
      "323 0.007553379982709885\n",
      "324 0.007200181949883699\n",
      "325 0.006863656919449568\n",
      "326 0.00654265284538269\n",
      "327 0.006236561574041843\n",
      "328 0.005944761913269758\n",
      "329 0.005666198208928108\n",
      "330 0.005400847643613815\n",
      "331 0.0051480308175086975\n",
      "332 0.0049063521437346935\n",
      "333 0.004676346201449633\n",
      "334 0.004457048140466213\n",
      "335 0.00424786563962698\n",
      "336 0.00404843594878912\n",
      "337 0.003858234267681837\n",
      "338 0.0036770079750567675\n",
      "339 0.003504228312522173\n",
      "340 0.003339442191645503\n",
      "341 0.0031823033932596445\n",
      "342 0.003032673615962267\n",
      "343 0.00288983853533864\n",
      "344 0.0027537832502275705\n",
      "345 0.0026240183506160975\n",
      "346 0.0025004481431096792\n",
      "347 0.002382527105510235\n",
      "348 0.0022701220586895943\n",
      "349 0.002162923803552985\n",
      "350 0.002060842467471957\n",
      "351 0.001963510410860181\n",
      "352 0.0018707378767430782\n",
      "353 0.0017822845838963985\n",
      "354 0.001697997679002583\n",
      "355 0.0016178140649572015\n",
      "356 0.0015412119682878256\n",
      "357 0.0014681273605674505\n",
      "358 0.0013986037811264396\n",
      "359 0.0013322994345799088\n",
      "360 0.0012691081501543522\n",
      "361 0.0012088974472135305\n",
      "362 0.001151501783169806\n",
      "363 0.001096819993108511\n",
      "364 0.0010446718661114573\n",
      "365 0.0009950095554813743\n",
      "366 0.0009476402192376554\n",
      "367 0.000902548257727176\n",
      "368 0.0008595681865699589\n",
      "369 0.0008186014601960778\n",
      "370 0.0007795642595738173\n",
      "371 0.0007423567585647106\n",
      "372 0.0007069373968988657\n",
      "373 0.0006732249748893082\n",
      "374 0.0006410047644749284\n",
      "375 0.0006103675113990903\n",
      "376 0.0005811756127513945\n",
      "377 0.0005533915827982128\n",
      "378 0.0005268672248348594\n",
      "379 0.0005016201175749302\n",
      "380 0.00047757592983543873\n",
      "381 0.0004546679265331477\n",
      "382 0.0004328518407419324\n",
      "383 0.0004121344827581197\n",
      "384 0.0003922642790712416\n",
      "385 0.00037340185372158885\n",
      "386 0.000355449941707775\n",
      "387 0.0003383288567420095\n",
      "388 0.00032203507726080716\n",
      "389 0.000306513044051826\n",
      "390 0.0002917253877967596\n",
      "391 0.0002776390756480396\n",
      "392 0.0002642270701471716\n",
      "393 0.0002514584339223802\n",
      "394 0.00023929811140988022\n",
      "395 0.0002277194580528885\n",
      "396 0.0002166864142054692\n",
      "397 0.00020619173301383853\n",
      "398 0.0001961870730156079\n",
      "399 0.0001866663369582966\n",
      "400 0.0001776068238541484\n",
      "401 0.00016897440946195275\n",
      "402 0.00016075365419965237\n",
      "403 0.00015293085016310215\n",
      "404 0.0001454944722354412\n",
      "405 0.00013839844905305654\n",
      "406 0.0001316592242801562\n",
      "407 0.00012522179167717695\n",
      "408 0.00011911398178199306\n",
      "409 0.00011329452536301687\n",
      "410 0.0001077575288945809\n",
      "411 0.00010248417675029486\n",
      "412 9.74709982983768e-05\n",
      "413 9.269234578823671e-05\n",
      "414 8.815102046355605e-05\n",
      "415 8.382758096558973e-05\n",
      "416 7.971279410412535e-05\n",
      "417 7.579699740745127e-05\n",
      "418 7.207059388747439e-05\n",
      "419 6.852364458609372e-05\n",
      "420 6.515011045848951e-05\n",
      "421 6.194211891852319e-05\n",
      "422 5.8886504120891914e-05\n",
      "423 5.598169445875101e-05\n",
      "424 5.322420838638209e-05\n",
      "425 5.058886745246127e-05\n",
      "426 4.808947414858267e-05\n",
      "427 4.571105819195509e-05\n",
      "428 4.344765329733491e-05\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "429 4.1297284042229876e-05\n",
      "430 3.9249673136509955e-05\n",
      "431 3.7301993870642036e-05\n",
      "432 3.545052823028527e-05\n",
      "433 3.3691019780235365e-05\n",
      "434 3.2015552278608084e-05\n",
      "435 3.0420967959798872e-05\n",
      "436 2.890893301810138e-05\n",
      "437 2.74712383543374e-05\n",
      "438 2.6098401576746255e-05\n",
      "439 2.4797802325338125e-05\n",
      "440 2.3561869966215454e-05\n",
      "441 2.2385444026440382e-05\n",
      "442 2.1267582269501872e-05\n",
      "443 2.0203089661663398e-05\n",
      "444 1.9192844774806872e-05\n",
      "445 1.8230841305921786e-05\n",
      "446 1.7318037862423807e-05\n",
      "447 1.6449723261757754e-05\n",
      "448 1.562457691761665e-05\n",
      "449 1.4840144103800412e-05\n",
      "450 1.409533797414042e-05\n",
      "451 1.3387114449869841e-05\n",
      "452 1.2712825991911814e-05\n",
      "453 1.2073536709067412e-05\n",
      "454 1.1465210263850167e-05\n",
      "455 1.0887116332014557e-05\n",
      "456 1.0337735147913918e-05\n",
      "457 9.81609719019616e-06\n",
      "458 9.320682693214621e-06\n",
      "459 8.849255209497642e-06\n",
      "460 8.402344064961653e-06\n",
      "461 7.977385394042358e-06\n",
      "462 7.57272437112988e-06\n",
      "463 7.188868949015159e-06\n",
      "464 6.824670890637208e-06\n",
      "465 6.479081548604881e-06\n",
      "466 6.150191438791808e-06\n",
      "467 5.837410753883887e-06\n",
      "468 5.5405430430255365e-06\n",
      "469 5.259076715447009e-06\n",
      "470 4.991058631276246e-06\n",
      "471 4.736447863251669e-06\n",
      "472 4.495564553508302e-06\n",
      "473 4.266698852006812e-06\n",
      "474 4.049014023621567e-06\n",
      "475 3.842098976747366e-06\n",
      "476 3.6463525248109363e-06\n",
      "477 3.459550043771742e-06\n",
      "478 3.2829811971168965e-06\n",
      "479 3.114487753919093e-06\n",
      "480 2.955280706373742e-06\n",
      "481 2.804437144732219e-06\n",
      "482 2.6597372198011726e-06\n",
      "483 2.523453986214008e-06\n",
      "484 2.3942884581629187e-06\n",
      "485 2.271245421070489e-06\n",
      "486 2.1546265998040326e-06\n",
      "487 2.0439254058146616e-06\n",
      "488 1.938894001796143e-06\n",
      "489 1.8390732066109194e-06\n",
      "490 1.7443379647374968e-06\n",
      "491 1.654603806855448e-06\n",
      "492 1.569268874845875e-06\n",
      "493 1.4882393770676572e-06\n",
      "494 1.4111951713857707e-06\n",
      "495 1.3382320958044147e-06\n",
      "496 1.269111749024887e-06\n",
      "497 1.2034716974085313e-06\n",
      "498 1.1412265621402184e-06\n",
      "499 1.0824967375810957e-06\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# N is batch size; D_in is input dimension;\n",
    "# H is hidden dimension; D_out is output dimension.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# Create random Tensors to hold inputs and outputs\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "# Use the nn package to define our model and loss function.\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(D_in, H),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(H, D_out),\n",
    ")\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "# Use the optim package to define an Optimizer that will update the weights of\n",
    "# the model for us. Here we will use Adam; the optim package contains many other\n",
    "# optimization algorithms. The first argument to the Adam constructor tells the\n",
    "# optimizer which Tensors it should update.\n",
    "learning_rate = 1e-4\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "for t in range(500):\n",
    "    # Forward pass: compute predicted y by passing x to the model.\n",
    "    y_pred = model(x)\n",
    "\n",
    "    # Compute and print loss.\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "\n",
    "    # Before the backward pass, use the optimizer object to zero all of the\n",
    "    # gradients for the variables it will update (which are the learnable\n",
    "    # weights of the model). This is because by default, gradients are\n",
    "    # accumulated in buffers (i.e., not overwritten) whenever .backward()\n",
    "    # is called. Check out the docs of torch.autograd.backward for more details.\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Backward pass: compute gradient of the loss with respect to model\n",
    "    # parameters\n",
    "    loss.backward()\n",
    "\n",
    "    # Calling the step function on an Optimizer makes an update to its\n",
    "    # parameters\n",
    "    optimizer.step()"
   ]
  },
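  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments in the training loop above note that `.backward()` accumulates gradients into each parameter's `.grad` buffer instead of overwriting it. This is why `optimizer.zero_grad()` is called on every iteration. The minimal sketch below demonstrates this with a single scalar parameter. The name `w`, the constant `3.0`, and the use of `torch.optim.SGD` are arbitrary choices made for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical single-parameter setup, used only to observe .grad\n",
    "w = torch.nn.Parameter(torch.tensor(2.0))\n",
    "optimizer = torch.optim.SGD([w], lr=0.1)\n",
    "\n",
    "grads = []\n",
    "for _ in range(2):\n",
    "    loss = 3.0 * w       # d(loss)/dw = 3\n",
    "    loss.backward()      # adds 3 to w.grad instead of overwriting it\n",
    "    grads.append(w.grad.item())\n",
    "print(grads)             # [3.0, 6.0]\n",
    "\n",
    "optimizer.zero_grad()    # clears w.grad before the next backward pass\n",
    "print(w.grad)            # None on PyTorch 2.x; tensor(0.) on older versions\n",
    "```\n",
    "\n",
    "The value of `w.grad` after `zero_grad()` depends on the PyTorch version. Since PyTorch 2.0 the default `set_to_none=True` sets it to `None`, while older versions fill it with zeros. In both cases the next `.backward()` starts from a clean gradient."
   ]
  },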
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "PyTorch: Custom nn Modules\n",
    "--------------------------\n",
    "\n",
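    "We can also define a model by subclassing `nn.Module`. When a model needs to be more complex than a plain `Sequential` chain of existing modules, define it as an `nn.Module` subclass. Create the submodules in `__init__` and describe the computation in `forward`.\n",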
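    "\n",
    "The hypothetical `ResidualMLP` below is a minimal sketch of something a `Sequential` chain cannot express. It adds its input back to its output (a skip connection), so `forward` needs access to `x` twice. The class name and the sizes `D` and `H` are illustrative assumptions and are not part of this tutorial.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class ResidualMLP(torch.nn.Module):\n",
    "    # Hypothetical example: output = x + f(x), which a plain\n",
    "    # torch.nn.Sequential chain cannot express on its own.\n",
    "    def __init__(self, D, H):\n",
    "        super().__init__()\n",
    "        self.linear1 = torch.nn.Linear(D, H)\n",
    "        self.linear2 = torch.nn.Linear(H, D)\n",
    "\n",
    "    def forward(self, x):\n",
    "        h = torch.relu(self.linear1(x))\n",
    "        return x + self.linear2(h)  # skip connection reuses the input x\n",
    "\n",
    "\n",
    "model = ResidualMLP(D=10, H=32)\n",
    "x = torch.randn(4, 10)\n",
    "y = model(x)\n",
    "print(y.shape)  # torch.Size([4, 10])\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(n_params)  # 682 = (10*32 + 32) + (32*10 + 10)\n",
    "```\n",
    "\n",
    "Because the output is added to `x`, `linear2` must map back to the input width `D`. Registering the layers as attributes in `__init__` is what makes them appear in `model.parameters()`, so an optimizer can update them.\n",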
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 656.6958618164062\n",
      "1 608.1090087890625\n",
      "2 566.172607421875\n",
      "3 529.2335815429688\n",
      "4 496.7382507324219\n",
      "5 467.453125\n",
      "6 440.5755310058594\n",
      "7 416.12872314453125\n",
      "8 393.6068420410156\n",
      "9 372.708251953125\n",
      "10 353.00006103515625\n",
      "11 334.477783203125\n",
      "12 316.97283935546875\n",
      "13 300.36737060546875\n",
      "14 284.6544189453125\n",
      "15 269.65936279296875\n",
      "16 255.33456420898438\n",
      "17 241.66688537597656\n",
      "18 228.60800170898438\n",
      "19 216.09536743164062\n",
      "20 204.13780212402344\n",
      "21 192.75645446777344\n",
      "22 181.89234924316406\n",
      "23 171.58370971679688\n",
      "24 161.7939453125\n",
      "25 152.4780731201172\n",
      "26 143.59371948242188\n",
      "27 135.14727783203125\n",
      "28 127.13992309570312\n",
      "29 119.55585479736328\n",
      "30 112.37797546386719\n",
      "31 105.62073516845703\n",
      "32 99.24383544921875\n",
      "33 93.24134826660156\n",
      "34 87.58341979980469\n",
      "35 82.25212860107422\n",
      "36 77.24210357666016\n",
      "37 72.55087280273438\n",
      "38 68.1427230834961\n",
      "39 64.00277709960938\n",
      "40 60.1308479309082\n",
      "41 56.49887466430664\n",
      "42 53.0952033996582\n",
      "43 49.906524658203125\n",
      "44 46.91959762573242\n",
      "45 44.11970520019531\n",
      "46 41.50297164916992\n",
      "47 39.062870025634766\n",
      "48 36.77130126953125\n",
      "49 34.62575149536133\n",
      "50 32.61510467529297\n",
      "51 30.732582092285156\n",
      "52 28.968435287475586\n",
      "53 27.313982009887695\n",
      "54 25.76266098022461\n",
      "55 24.30921173095703\n",
      "56 22.946903228759766\n",
      "57 21.669231414794922\n",
      "58 20.470420837402344\n",
      "59 19.347976684570312\n",
      "60 18.293415069580078\n",
      "61 17.30330467224121\n",
      "62 16.373027801513672\n",
      "63 15.501211166381836\n",
      "64 14.681639671325684\n",
      "65 13.909913063049316\n",
      "66 13.183857917785645\n",
      "67 12.499263763427734\n",
      "68 11.854406356811523\n",
      "69 11.247345924377441\n",
      "70 10.675314903259277\n",
      "71 10.13557243347168\n",
      "72 9.626264572143555\n",
      "73 9.145630836486816\n",
      "74 8.692134857177734\n",
      "75 8.26422119140625\n",
      "76 7.859929084777832\n",
      "77 7.478569030761719\n",
      "78 7.118776321411133\n",
      "79 6.778269290924072\n",
      "80 6.455151557922363\n",
      "81 6.149427890777588\n",
      "82 5.859961986541748\n",
      "83 5.586189270019531\n",
      "84 5.326669692993164\n",
      "85 5.081073760986328\n",
      "86 4.848598003387451\n",
      "87 4.628039836883545\n",
      "88 4.418812274932861\n",
      "89 4.22011137008667\n",
      "90 4.03148889541626\n",
      "91 3.85233211517334\n",
      "92 3.6820061206817627\n",
      "93 3.5200538635253906\n",
      "94 3.366295576095581\n",
      "95 3.220104694366455\n",
      "96 3.081094980239868\n",
      "97 2.9478847980499268\n",
      "98 2.8211557865142822\n",
      "99 2.700608491897583\n",
      "100 2.585635185241699\n",
      "101 2.476266860961914\n",
      "102 2.3720734119415283\n",
      "103 2.27278733253479\n",
      "104 2.178467273712158\n",
      "105 2.0886454582214355\n",
      "106 2.0029616355895996\n",
      "107 1.9212790727615356\n",
      "108 1.8433643579483032\n",
      "109 1.7690125703811646\n",
      "110 1.6980165243148804\n",
      "111 1.630236029624939\n",
      "112 1.5654228925704956\n",
      "113 1.5034924745559692\n",
      "114 1.444340467453003\n",
      "115 1.387866735458374\n",
      "116 1.3338878154754639\n",
      "117 1.2822437286376953\n",
      "118 1.232841968536377\n",
      "119 1.1855518817901611\n",
      "120 1.1402850151062012\n",
      "121 1.0969470739364624\n",
      "122 1.0554102659225464\n",
      "123 1.0156100988388062\n",
      "124 0.9774888753890991\n",
      "125 0.9409765601158142\n",
      "126 0.9060029983520508\n",
      "127 0.8724609017372131\n",
      "128 0.8402785062789917\n",
      "129 0.8092858195304871\n",
      "130 0.7795636653900146\n",
      "131 0.751034140586853\n",
      "132 0.7236625552177429\n",
      "133 0.697391927242279\n",
      "134 0.6721824407577515\n",
      "135 0.6479606628417969\n",
      "136 0.6246941685676575\n",
      "137 0.6023442149162292\n",
      "138 0.5808628797531128\n",
      "139 0.5602169036865234\n",
      "140 0.5403794050216675\n",
      "141 0.5213167071342468\n",
      "142 0.5029921531677246\n",
      "143 0.48536738753318787\n",
      "144 0.4684107005596161\n",
      "145 0.4521111249923706\n",
      "146 0.4364219009876251\n",
      "147 0.4213317036628723\n",
      "148 0.4068053662776947\n",
      "149 0.3928261399269104\n",
      "150 0.3793690800666809\n",
      "151 0.3664127290248871\n",
      "152 0.35393786430358887\n",
      "153 0.34193798899650574\n",
      "154 0.33036914467811584\n",
      "155 0.3192187547683716\n",
      "156 0.3084754943847656\n",
      "157 0.2981330454349518\n",
      "158 0.28816646337509155\n",
      "159 0.2785727083683014\n",
      "160 0.26933375000953674\n",
      "161 0.2604302763938904\n",
      "162 0.2518427073955536\n",
      "163 0.2435620129108429\n",
      "164 0.2355814278125763\n",
      "165 0.22787995636463165\n",
      "166 0.22044798731803894\n",
      "167 0.2132810354232788\n",
      "168 0.20636311173439026\n",
      "169 0.19969020783901215\n",
      "170 0.19325482845306396\n",
      "171 0.18703718483448029\n",
      "172 0.18104007840156555\n",
      "173 0.17524453997612\n",
      "174 0.16964828968048096\n",
      "175 0.16424523293972015\n",
      "176 0.1590285748243332\n",
      "177 0.15399070084095\n",
      "178 0.14912430942058563\n",
      "179 0.14442437887191772\n",
      "180 0.13988198339939117\n",
      "181 0.13549421727657318\n",
      "182 0.13124999403953552\n",
      "183 0.12715040147304535\n",
      "184 0.12318697571754456\n",
      "185 0.11935491114854813\n",
      "186 0.11565019190311432\n",
      "187 0.11206966638565063\n",
      "188 0.10860667377710342\n",
      "189 0.10525806993246078\n",
      "190 0.10202305018901825\n",
      "191 0.09889409691095352\n",
      "192 0.09586735814809799\n",
      "193 0.09293802082538605\n",
      "194 0.09010337293148041\n",
      "195 0.0873604416847229\n",
      "196 0.0847071036696434\n",
      "197 0.08213980495929718\n",
      "198 0.07965682446956635\n",
      "199 0.07725474238395691\n",
      "200 0.07492762058973312\n",
      "201 0.07267512381076813\n",
      "202 0.07049493491649628\n",
      "203 0.06838559359312057\n",
      "204 0.06634149700403214\n",
      "205 0.06436219811439514\n",
      "206 0.062446292489767075\n",
      "207 0.06059153750538826\n",
      "208 0.0587935708463192\n",
      "209 0.05705287307500839\n",
      "210 0.05536716431379318\n",
      "211 0.05373508855700493\n",
      "212 0.052152544260025024\n",
      "213 0.05062079429626465\n",
      "214 0.049136947840452194\n",
      "215 0.04769892245531082\n",
      "216 0.0463058203458786\n",
      "217 0.04495447129011154\n",
      "218 0.043645307421684265\n",
      "219 0.04237687215209007\n",
      "220 0.04114643111824989\n",
      "221 0.03995450586080551\n",
      "222 0.03880004957318306\n",
      "223 0.03768027573823929\n",
      "224 0.03659489378333092\n",
      "225 0.03554360195994377\n",
      "226 0.03452374413609505\n",
      "227 0.03353426977992058\n",
      "228 0.032575368881225586\n",
      "229 0.03164541721343994\n",
      "230 0.03074309229850769\n",
      "231 0.02986789681017399\n",
      "232 0.029019085690379143\n",
      "233 0.028195498511195183\n",
      "234 0.0273965485394001\n",
      "235 0.02662237547338009\n",
      "236 0.025871429592370987\n",
      "237 0.025142081081867218\n",
      "238 0.024434441700577736\n",
      "239 0.023747552186250687\n",
      "240 0.023081056773662567\n",
      "241 0.022434504702687263\n",
      "242 0.021806789562106133\n",
      "243 0.021197138354182243\n",
      "244 0.020606014877557755\n",
      "245 0.02003205008804798\n",
      "246 0.019474638625979424\n",
      "247 0.018934227526187897\n",
      "248 0.01840922422707081\n",
      "249 0.01789930835366249\n",
      "250 0.017403993755578995\n",
      "251 0.016923150047659874\n",
      "252 0.016456523910164833\n",
      "253 0.0160031970590353\n",
      "254 0.015562880784273148\n",
      "255 0.01513546984642744\n",
      "256 0.014720354229211807\n",
      "257 0.01431703194975853\n",
      "258 0.013925755396485329\n",
      "259 0.013545180670917034\n",
      "260 0.013175630941987038\n",
      "261 0.01281663216650486\n",
      "262 0.012467730790376663\n",
      "263 0.012128635309636593\n",
      "264 0.011799173429608345\n",
      "265 0.011479041539132595\n",
      "266 0.011168107390403748\n",
      "267 0.01086602546274662\n",
      "268 0.01057237759232521\n",
      "269 0.010287342593073845\n",
      "270 0.010010016150772572\n",
      "271 0.009740477427840233\n",
      "272 0.009478610940277576\n",
      "273 0.009224051609635353\n",
      "274 0.008976546116173267\n",
      "275 0.00873596128076315\n",
      "276 0.008502164855599403\n",
      "277 0.008274776861071587\n",
      "278 0.00805367436259985\n",
      "279 0.007838784717023373\n",
      "280 0.007630025502294302\n",
      "281 0.007427022326737642\n",
      "282 0.0072296857833862305\n",
      "283 0.007037700153887272\n",
      "284 0.006851076614111662\n",
      "285 0.00666955066844821\n",
      "286 0.006492926273494959\n",
      "287 0.00632116524502635\n",
      "288 0.006154042202979326\n",
      "289 0.005991632118821144\n",
      "290 0.0058336709626019\n",
      "291 0.0056800455786287785\n",
      "292 0.005530708469450474\n",
      "293 0.005385328084230423\n",
      "294 0.0052438536658883095\n",
      "295 0.005106212105602026\n",
      "296 0.004972323775291443\n",
      "297 0.004842082504183054\n",
      "298 0.004715400282293558\n",
      "299 0.004592181649059057\n",
      "300 0.004472256172448397\n",
      "301 0.004355618264526129\n",
      "302 0.0042420197278261185\n",
      "303 0.004131658002734184\n",
      "304 0.004024124704301357\n",
      "305 0.003919483628123999\n",
      "306 0.003817657707259059\n",
      "307 0.0037186346016824245\n",
      "308 0.003622243879362941\n",
      "309 0.003528367727994919\n",
      "310 0.0034370282664895058\n",
      "311 0.0033481812570244074\n",
      "312 0.0032616364769637585\n",
      "313 0.0031774123199284077\n",
      "314 0.00309544475749135\n",
      "315 0.003015762660652399\n",
      "316 0.002938046120107174\n",
      "317 0.002862409455701709\n",
      "318 0.002788822166621685\n",
      "319 0.0027171699330210686\n",
      "320 0.0026473761536180973\n",
      "321 0.002579432912170887\n",
      "322 0.002513323212042451\n",
      "323 0.0024489827919751406\n",
      "324 0.0023862975649535656\n",
      "325 0.0023252193350344896\n",
      "326 0.0022658223751932383\n",
      "327 0.0022080307826399803\n",
      "328 0.0021516724955290556\n",
      "329 0.002096814103424549\n",
      "330 0.002043382963165641\n",
      "331 0.0019913306459784508\n",
      "332 0.0019406524952501059\n",
      "333 0.0018913011299446225\n",
      "334 0.0018432465149089694\n",
      "335 0.0017964182188734412\n",
      "336 0.0017508040182292461\n",
      "337 0.001706396578811109\n",
      "338 0.001663187169469893\n",
      "339 0.0016211027977988124\n",
      "340 0.0015800849068909883\n",
      "341 0.0015401256969198585\n",
      "342 0.0015011945506557822\n",
      "343 0.001463254913687706\n",
      "344 0.0014263042248785496\n",
      "345 0.001390322926454246\n",
      "346 0.0013552795862779021\n",
      "347 0.0013211170444265008\n",
      "348 0.0012878385605290532\n",
      "349 0.0012554213171824813\n",
      "350 0.0012238684576004744\n",
      "351 0.0011931274784728885\n",
      "352 0.0011631487868726254\n",
      "353 0.0011339503107592463\n",
      "354 0.0011054981732740998\n",
      "355 0.0010777906281873584\n",
      "356 0.0010507676051929593\n",
      "357 0.0010244473814964294\n",
      "358 0.0009987998055294156\n",
      "359 0.00097382947569713\n",
      "360 0.0009494827827438712\n",
      "361 0.0009257435449399054\n",
      "362 0.000902607396710664\n",
      "363 0.0008801089134067297\n",
      "364 0.0008581369183957577\n",
      "365 0.0008367350674234331\n",
      "366 0.0008158780983649194\n",
      "367 0.0007955483160912991\n",
      "368 0.0007757404237054288\n",
      "369 0.0007564350962638855\n",
      "370 0.000737626978661865\n",
      "371 0.0007192868506535888\n",
      "372 0.0007014275179244578\n",
      "373 0.0006839964189566672\n",
      "374 0.0006670261500403285\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "375 0.0006505012279376388\n",
      "376 0.0006343786371871829\n",
      "377 0.0006186614627949893\n",
      "378 0.0006033276440575719\n",
      "379 0.0005883832345716655\n",
      "380 0.0005738206673413515\n",
      "381 0.0005596213741227984\n",
      "382 0.0005457888473756611\n",
      "383 0.0005322962533682585\n",
      "384 0.0005191444652155042\n",
      "385 0.000506328884512186\n",
      "386 0.0004938290221616626\n",
      "387 0.00048163760220631957\n",
      "388 0.0004697689728345722\n",
      "389 0.0004582055553328246\n",
      "390 0.00044691533548757434\n",
      "391 0.00043590739369392395\n",
      "392 0.0004251690406817943\n",
      "393 0.0004147063591517508\n",
      "394 0.00040450665983371437\n",
      "395 0.0003945553908124566\n",
      "396 0.0003848606429528445\n",
      "397 0.00037539892946369946\n",
      "398 0.0003661849768832326\n",
      "399 0.00035720854066312313\n",
      "400 0.000348439411027357\n",
      "401 0.0003398970584385097\n",
      "402 0.00033156739664264023\n",
      "403 0.0003234421892557293\n",
      "404 0.0003155224258080125\n",
      "405 0.00030779733788222075\n",
      "406 0.0003002593875862658\n",
      "407 0.00029291390092112124\n",
      "408 0.00028574312455020845\n",
      "409 0.0002787590492516756\n",
      "410 0.000271946337306872\n",
      "411 0.00026530082686804235\n",
      "412 0.00025882109184749424\n",
      "413 0.0002525025047361851\n",
      "414 0.00024633959401398897\n",
      "415 0.00024033462977968156\n",
      "416 0.00023447422427125275\n",
      "417 0.00022875398281030357\n",
      "418 0.00022317911498248577\n",
      "419 0.00021774203923996538\n",
      "420 0.00021243662922643125\n",
      "421 0.00020726659568026662\n",
      "422 0.0002022238913923502\n",
      "423 0.00019730582425836474\n",
      "424 0.00019250763580203056\n",
      "425 0.0001878279581433162\n",
      "426 0.00018326277495361865\n",
      "427 0.00017881377425510436\n",
      "428 0.0001744656328810379\n",
      "429 0.00017023469263222069\n",
      "430 0.0001660993875702843\n",
      "431 0.0001620692346477881\n",
      "432 0.00015813524078112096\n",
      "433 0.00015429955965373665\n",
      "434 0.0001505605468992144\n",
      "435 0.00014691019896417856\n",
      "436 0.00014335215382743627\n",
      "437 0.00013987701095174998\n",
      "438 0.0001364911295240745\n",
      "439 0.00013318308629095554\n",
      "440 0.00012996031728107482\n",
      "441 0.00012681707448791713\n",
      "442 0.0001237515825778246\n",
      "443 0.00012076242273906246\n",
      "444 0.0001178424499812536\n",
      "445 0.0001149943345808424\n",
      "446 0.00011221617023693398\n",
      "447 0.00010950590512948111\n",
      "448 0.00010686153109418228\n",
      "449 0.00010428134555695578\n",
      "450 0.00010176430077990517\n",
      "451 9.930722444551066e-05\n",
      "452 9.691028390079737e-05\n",
      "453 9.457595297135413e-05\n",
      "454 9.22925173654221e-05\n",
      "455 9.006822801893577e-05\n",
      "456 8.79026047186926e-05\n",
      "457 8.57846753206104e-05\n",
      "458 8.371831063413993e-05\n",
      "459 8.170259388862178e-05\n",
      "460 7.973532046889886e-05\n",
      "461 7.781643944326788e-05\n",
      "462 7.594288035761565e-05\n",
      "463 7.411641854560003e-05\n",
      "464 7.233369251480326e-05\n",
      "465 7.059406925691292e-05\n",
      "466 6.889844371471554e-05\n",
      "467 6.724218837916851e-05\n",
      "468 6.56265692668967e-05\n",
      "469 6.405053864000365e-05\n",
      "470 6.251001468626782e-05\n",
      "471 6.100907557993196e-05\n",
      "472 5.9547543060034513e-05\n",
      "473 5.8119461755268276e-05\n",
      "474 5.6722430599620566e-05\n",
      "475 5.536193202715367e-05\n",
      "476 5.403558679972775e-05\n",
      "477 5.274035356706008e-05\n",
      "478 5.1476214139256626e-05\n",
      "479 5.0242502766195685e-05\n",
      "480 4.903853914584033e-05\n",
      "481 4.78645451948978e-05\n",
      "482 4.6718425437575206e-05\n",
      "483 4.559816079563461e-05\n",
      "484 4.4506497943075374e-05\n",
      "485 4.344211265561171e-05\n",
      "486 4.240254929754883e-05\n",
      "487 4.13884044974111e-05\n",
      "488 4.039857958559878e-05\n",
      "489 3.943321280530654e-05\n",
      "490 3.848948108498007e-05\n",
      "491 3.7569927371805534e-05\n",
      "492 3.667309647426009e-05\n",
      "493 3.5794582800008357e-05\n",
      "494 3.4940730984089896e-05\n",
      "495 3.410521094338037e-05\n",
      "496 3.3291358704445884e-05\n",
      "497 3.249421206419356e-05\n",
      "498 3.171973003190942e-05\n",
      "499 3.09631614072714e-05\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class TwoLayerNet(torch.nn.Module):\n",
    "    def __init__(self, D_in, H, D_out):\n",
    "        \"\"\"\n",
    "        In the constructor we instantiate two nn.Linear modules and assign them as\n",
    "        member variables.\n",
    "        \"\"\"\n",
    "        super(TwoLayerNet, self).__init__()\n",
    "        self.linear1 = torch.nn.Linear(D_in, H)\n",
    "        self.linear2 = torch.nn.Linear(H, D_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        In the forward function we accept a Tensor of input data and we must return\n",
    "        a Tensor of output data. We can use Modules defined in the constructor as\n",
    "        well as arbitrary operators on Tensors.\n",
    "        \"\"\"\n",
    "        h_relu = self.linear1(x).clamp(min=0)\n",
    "        y_pred = self.linear2(h_relu)\n",
    "        return y_pred\n",
    "\n",
    "\n",
    "# N is batch size; D_in is input dimension;\n",
    "# H is hidden dimension; D_out is output dimension.\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "# Create random Tensors to hold inputs and outputs\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "# Construct our model by instantiating the class defined above\n",
    "model = TwoLayerNet(D_in, H, D_out)\n",
    "\n",
    "# Construct our loss function and an Optimizer. The call to model.parameters()\n",
    "# in the SGD constructor will contain the learnable parameters of the two\n",
    "# nn.Linear modules which are members of the model.\n",
    "criterion = torch.nn.MSELoss(reduction='sum')\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)\n",
    "for t in range(500):\n",
    "    # Forward pass: Compute predicted y by passing x to the model\n",
    "    y_pred = model(x)\n",
    "\n",
    "    # Compute and print loss\n",
    "    loss = criterion(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "\n",
    "    # Zero gradients, perform a backward pass, and update the weights.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
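    "# FizzBuzz\n",
    "\n",
    "FizzBuzz is a simple counting game. The rules: count upward starting from 1; for a multiple of 3 say \"fizz\", for a multiple of 5 say \"buzz\", for a multiple of 15 (i.e. of both 3 and 5) say \"fizzbuzz\", and otherwise just say the number.\n",
    "\n",
    "We can write a small program that decides whether to return the number itself or fizz, buzz, or fizzbuzz."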
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "buzz\n",
      "fizz\n",
      "fizzbuzz\n"
     ]
    }
   ],
   "source": [
    "# Encode the desired output as a class index into [number, \"fizz\", \"buzz\", \"fizzbuzz\"]\n",
    "def fizz_buzz_encode(i):\n",
    "    if   i % 15 == 0: return 3\n",
    "    elif i % 5  == 0: return 2\n",
    "    elif i % 3  == 0: return 1\n",
    "    else:             return 0\n",
    "\n",
    "def fizz_buzz_decode(i, prediction):\n",
    "    return [str(i), \"fizz\", \"buzz\", \"fizzbuzz\"][prediction]\n",
    "\n",
    "print(fizz_buzz_decode(1, fizz_buzz_encode(1)))\n",
    "print(fizz_buzz_decode(2, fizz_buzz_encode(2)))\n",
    "print(fizz_buzz_decode(5, fizz_buzz_encode(5)))\n",
    "print(fizz_buzz_decode(12, fizz_buzz_encode(12)))\n",
    "print(fizz_buzz_decode(15, fizz_buzz_encode(15)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
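    "We first define the model's inputs and outputs (the training data): each integer is represented by its binary digits, and its label is the class index returned by `fizz_buzz_encode`."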
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
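    "import numpy as np\n",
    "import torch\n",
    "\n",
    "NUM_DIGITS = 10\n",
    "\n",
    "# Represent each input by an array of its binary digits (least significant bit first).\n",
    "def binary_encode(i, num_digits):\n",
    "    return np.array([i >> d & 1 for d in range(num_digits)])\n",
    "\n",
    "# Train on 101 .. 2**NUM_DIGITS - 1, so the numbers 1 .. 100 are never seen during training.\n",
    "# Stacking into one ndarray first avoids the slow tensor-from-list-of-ndarrays path.\n",
    "trX = torch.tensor(np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]), dtype=torch.float)\n",
    "trY = torch.tensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)], dtype=torch.long)"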
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
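    "Next, we define the model in PyTorch: a two-layer network that maps the `NUM_DIGITS` binary digits through one hidden ReLU layer of `NUM_HIDDEN` units to 4 output scores, one per class."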
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define the model\n",
    "NUM_HIDDEN = 100\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(NUM_HIDDEN, 4)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
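    "- To teach the model the FizzBuzz game, we need to define a loss function and an optimization algorithm.\n",
    "- The optimization algorithm repeatedly adjusts the model's parameters to reduce the loss, so that the model reaches as low a loss as possible on this task.\n",
    "- A low loss usually indicates that the model performs well; a high loss indicates that it performs poorly.\n",
    "- Since FizzBuzz is essentially a 4-class classification problem, we use the Cross Entropy Loss. `torch.nn.CrossEntropyLoss` applies log-softmax internally, which is why the model outputs raw scores (logits) and has no softmax layer.\n",
    "- For the optimizer, we use Stochastic Gradient Descent (SGD)."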
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is the training loop for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 Loss: 1.1863759756088257\n",
      "Epoch: 1 Loss: 1.157894492149353\n",
      "Epoch: 2 Loss: 1.1498943567276\n",
      "Epoch: 3 Loss: 1.1466518640518188\n",
      "Epoch: 4 Loss: 1.1449916362762451\n",
      "Epoch: 5 Loss: 1.143990159034729\n",
      "Epoch: 6 Loss: 1.143303394317627\n",
      "Epoch: 7 Loss: 1.142789602279663\n",
      "Epoch: 8 Loss: 1.1423760652542114\n",
      "Epoch: 9 Loss: 1.1420289278030396\n",
      "Epoch: 10 Loss: 1.141723394393921\n",
      "Epoch: 11 Loss: 1.141450047492981\n",
      "Epoch: 12 Loss: 1.1412018537521362\n",
      "Epoch: 13 Loss: 1.1409754753112793\n",
      "Epoch: 14 Loss: 1.1407641172409058\n",
      "Epoch: 15 Loss: 1.1405678987503052\n",
      "Epoch: 16 Loss: 1.1403822898864746\n",
      "Epoch: 17 Loss: 1.1402071714401245\n",
      "Epoch: 18 Loss: 1.1400408744812012\n",
      "Epoch: 19 Loss: 1.1398837566375732\n",
      "Epoch: 20 Loss: 1.1397329568862915\n",
      "Epoch: 21 Loss: 1.139588713645935\n",
      "Epoch: 22 Loss: 1.1394506692886353\n",
      "Epoch: 23 Loss: 1.139319896697998\n",
      "Epoch: 24 Loss: 1.1391912698745728\n",
      "Epoch: 25 Loss: 1.1390684843063354\n",
      "Epoch: 26 Loss: 1.1389490365982056\n",
      "Epoch: 27 Loss: 1.138834834098816\n",
      "Epoch: 28 Loss: 1.138723611831665\n",
      "Epoch: 29 Loss: 1.1386148929595947\n",
      "Epoch: 30 Loss: 1.138511061668396\n",
      "Epoch: 31 Loss: 1.1384084224700928\n",
      "Epoch: 32 Loss: 1.1383094787597656\n",
      "Epoch: 33 Loss: 1.138213872909546\n",
      "Epoch: 34 Loss: 1.1381211280822754\n",
      "Epoch: 35 Loss: 1.1380314826965332\n",
      "Epoch: 36 Loss: 1.1379411220550537\n",
      "Epoch: 37 Loss: 1.1378580331802368\n",
      "Epoch: 38 Loss: 1.137773871421814\n",
      "Epoch: 39 Loss: 1.1376910209655762\n",
      "Epoch: 40 Loss: 1.1376110315322876\n",
      "Epoch: 41 Loss: 1.1375325918197632\n",
      "Epoch: 42 Loss: 1.1374561786651611\n",
      "Epoch: 43 Loss: 1.1373814344406128\n",
      "Epoch: 44 Loss: 1.1373075246810913\n",
      "Epoch: 45 Loss: 1.1372358798980713\n",
      "Epoch: 46 Loss: 1.1371639966964722\n",
      "Epoch: 47 Loss: 1.13709557056427\n",
      "Epoch: 48 Loss: 1.1370269060134888\n",
      "Epoch: 49 Loss: 1.136959433555603\n",
      "Epoch: 50 Loss: 1.1368935108184814\n",
      "Epoch: 51 Loss: 1.1368285417556763\n",
      "Epoch: 52 Loss: 1.1367650032043457\n",
      "Epoch: 53 Loss: 1.136702299118042\n",
      "Epoch: 54 Loss: 1.1366392374038696\n",
      "Epoch: 55 Loss: 1.1365766525268555\n",
      "Epoch: 56 Loss: 1.1365140676498413\n",
      "Epoch: 57 Loss: 1.1364532709121704\n",
      "Epoch: 58 Loss: 1.1363914012908936\n",
      "Epoch: 59 Loss: 1.1363309621810913\n",
      "Epoch: 60 Loss: 1.1362704038619995\n",
      "Epoch: 61 Loss: 1.1362107992172241\n",
      "Epoch: 62 Loss: 1.1361552476882935\n",
      "Epoch: 63 Loss: 1.1360974311828613\n",
      "Epoch: 64 Loss: 1.1360419988632202\n",
      "Epoch: 65 Loss: 1.1359858512878418\n",
      "Epoch: 66 Loss: 1.1359294652938843\n",
      "Epoch: 67 Loss: 1.1358730792999268\n",
      "Epoch: 68 Loss: 1.1358181238174438\n",
      "Epoch: 69 Loss: 1.1357614994049072\n",
      "Epoch: 70 Loss: 1.1357064247131348\n",
      "Epoch: 71 Loss: 1.135650396347046\n",
      "Epoch: 72 Loss: 1.1355968713760376\n",
      "Epoch: 73 Loss: 1.135541319847107\n",
      "Epoch: 74 Loss: 1.1354864835739136\n",
      "Epoch: 75 Loss: 1.1354326009750366\n",
      "Epoch: 76 Loss: 1.1353795528411865\n",
      "Epoch: 77 Loss: 1.1353275775909424\n",
      "Epoch: 78 Loss: 1.1352771520614624\n",
      "Epoch: 79 Loss: 1.135224461555481\n",
      "Epoch: 80 Loss: 1.1351745128631592\n",
      "Epoch: 81 Loss: 1.1351231336593628\n",
      "Epoch: 82 Loss: 1.1350746154785156\n",
      "Epoch: 83 Loss: 1.1350243091583252\n",
      "Epoch: 84 Loss: 1.1349738836288452\n",
      "Epoch: 85 Loss: 1.134922742843628\n",
      "Epoch: 86 Loss: 1.1348724365234375\n",
      "Epoch: 87 Loss: 1.134821891784668\n",
      "Epoch: 88 Loss: 1.1347724199295044\n",
      "Epoch: 89 Loss: 1.1347209215164185\n",
      "Epoch: 90 Loss: 1.1346715688705444\n",
      "Epoch: 91 Loss: 1.13462233543396\n",
      "Epoch: 92 Loss: 1.134572148323059\n",
      "Epoch: 93 Loss: 1.1345224380493164\n",
      "Epoch: 94 Loss: 1.1344718933105469\n",
      "Epoch: 95 Loss: 1.134421467781067\n",
      "Epoch: 96 Loss: 1.1343739032745361\n",
      "Epoch: 97 Loss: 1.1343233585357666\n",
      "Epoch: 98 Loss: 1.1342722177505493\n",
      "Epoch: 99 Loss: 1.1342231035232544\n",
      "Epoch: 100 Loss: 1.1341723203659058\n",
      "Epoch: 101 Loss: 1.1341253519058228\n",
      "Epoch: 102 Loss: 1.134071707725525\n",
      "Epoch: 103 Loss: 1.1340221166610718\n",
      "Epoch: 104 Loss: 1.1339722871780396\n",
      "Epoch: 105 Loss: 1.1339224576950073\n",
      "Epoch: 106 Loss: 1.1338707208633423\n",
      "Epoch: 107 Loss: 1.1338176727294922\n",
      "Epoch: 108 Loss: 1.1337658166885376\n",
      "Epoch: 109 Loss: 1.1337124109268188\n",
      "Epoch: 110 Loss: 1.1336588859558105\n",
      "Epoch: 111 Loss: 1.1336060762405396\n",
      "Epoch: 112 Loss: 1.1335557699203491\n",
      "Epoch: 113 Loss: 1.1335015296936035\n",
      "Epoch: 114 Loss: 1.1334484815597534\n",
      "Epoch: 115 Loss: 1.1333973407745361\n",
      "Epoch: 116 Loss: 1.133345603942871\n",
      "Epoch: 117 Loss: 1.133292555809021\n",
      "Epoch: 118 Loss: 1.1332420110702515\n",
      "Epoch: 119 Loss: 1.1331912279129028\n",
      "Epoch: 120 Loss: 1.13313889503479\n",
      "Epoch: 121 Loss: 1.1330863237380981\n",
      "Epoch: 122 Loss: 1.1330300569534302\n",
      "Epoch: 123 Loss: 1.132975459098816\n",
      "Epoch: 124 Loss: 1.1329208612442017\n",
      "Epoch: 125 Loss: 1.13286554813385\n",
      "Epoch: 126 Loss: 1.1328084468841553\n",
      "Epoch: 127 Loss: 1.1327511072158813\n",
      "Epoch: 128 Loss: 1.1326980590820312\n",
      "Epoch: 129 Loss: 1.1326403617858887\n",
      "Epoch: 130 Loss: 1.1325851678848267\n",
      "Epoch: 131 Loss: 1.1325281858444214\n",
      "Epoch: 132 Loss: 1.1324717998504639\n",
      "Epoch: 133 Loss: 1.1324169635772705\n",
      "Epoch: 134 Loss: 1.1323587894439697\n",
      "Epoch: 135 Loss: 1.1323022842407227\n",
      "Epoch: 136 Loss: 1.1322455406188965\n",
      "Epoch: 137 Loss: 1.1321899890899658\n",
      "Epoch: 138 Loss: 1.1321325302124023\n",
      "Epoch: 139 Loss: 1.132074236869812\n",
      "Epoch: 140 Loss: 1.1320151090621948\n",
      "Epoch: 141 Loss: 1.1319571733474731\n",
      "Epoch: 142 Loss: 1.1318986415863037\n",
      "Epoch: 143 Loss: 1.1318415403366089\n",
      "Epoch: 144 Loss: 1.1317806243896484\n",
      "Epoch: 145 Loss: 1.1317188739776611\n",
      "Epoch: 146 Loss: 1.1316568851470947\n",
      "Epoch: 147 Loss: 1.1315945386886597\n",
      "Epoch: 148 Loss: 1.1315333843231201\n",
      "Epoch: 149 Loss: 1.1314669847488403\n",
      "Epoch: 150 Loss: 1.1314020156860352\n",
      "Epoch: 151 Loss: 1.1313365697860718\n",
      "Epoch: 152 Loss: 1.1312707662582397\n",
      "Epoch: 153 Loss: 1.1312015056610107\n",
      "Epoch: 154 Loss: 1.1311382055282593\n",
      "Epoch: 155 Loss: 1.1310704946517944\n",
      "Epoch: 156 Loss: 1.1310042142868042\n",
      "Epoch: 157 Loss: 1.1309386491775513\n",
      "Epoch: 158 Loss: 1.1308705806732178\n",
      "Epoch: 159 Loss: 1.1308040618896484\n",
      "Epoch: 160 Loss: 1.1307406425476074\n",
      "Epoch: 161 Loss: 1.1306703090667725\n",
      "Epoch: 162 Loss: 1.1306068897247314\n",
      "Epoch: 163 Loss: 1.1305382251739502\n",
      "Epoch: 164 Loss: 1.1304738521575928\n",
      "Epoch: 165 Loss: 1.1304092407226562\n",
      "Epoch: 166 Loss: 1.1303397417068481\n",
      "Epoch: 167 Loss: 1.1302763223648071\n",
      "Epoch: 168 Loss: 1.1302064657211304\n",
      "Epoch: 169 Loss: 1.1301416158676147\n",
      "Epoch: 170 Loss: 1.1300736665725708\n",
      "Epoch: 171 Loss: 1.1300044059753418\n",
      "Epoch: 172 Loss: 1.1299372911453247\n",
      "Epoch: 173 Loss: 1.129868745803833\n",
      "Epoch: 174 Loss: 1.1297999620437622\n",
      "Epoch: 175 Loss: 1.1297296285629272\n",
      "Epoch: 176 Loss: 1.1296606063842773\n",
      "Epoch: 177 Loss: 1.129593014717102\n",
      "Epoch: 178 Loss: 1.1295201778411865\n",
      "Epoch: 179 Loss: 1.1294505596160889\n",
      "Epoch: 180 Loss: 1.1293814182281494\n",
      "Epoch: 181 Loss: 1.1293084621429443\n",
      "Epoch: 182 Loss: 1.1292396783828735\n",
      "Epoch: 183 Loss: 1.1291638612747192\n",
      "Epoch: 184 Loss: 1.1290931701660156\n",
      "Epoch: 185 Loss: 1.1290215253829956\n",
      "Epoch: 186 Loss: 1.1289492845535278\n",
      "Epoch: 187 Loss: 1.1288692951202393\n",
      "Epoch: 188 Loss: 1.1287916898727417\n",
      "Epoch: 189 Loss: 1.1287143230438232\n",
      "Epoch: 190 Loss: 1.128640055656433\n",
      "Epoch: 191 Loss: 1.1285561323165894\n",
      "Epoch: 192 Loss: 1.1284788846969604\n",
      "Epoch: 193 Loss: 1.1284000873565674\n",
      "Epoch: 194 Loss: 1.1283186674118042\n",
      "Epoch: 195 Loss: 1.1282389163970947\n",
      "Epoch: 196 Loss: 1.128156065940857\n",
      "Epoch: 197 Loss: 1.128070592880249\n",
      "Epoch: 198 Loss: 1.1279886960983276\n",
      "Epoch: 199 Loss: 1.1279044151306152\n",
      "Epoch: 200 Loss: 1.127820372581482\n",
      "Epoch: 201 Loss: 1.127732515335083\n",
      "Epoch: 202 Loss: 1.1276493072509766\n",
      "Epoch: 203 Loss: 1.1275635957717896\n",
      "Epoch: 204 Loss: 1.1274774074554443\n",
      "Epoch: 205 Loss: 1.1273856163024902\n",
      "Epoch: 206 Loss: 1.1272990703582764\n",
      "Epoch: 207 Loss: 1.127208948135376\n",
      "Epoch: 208 Loss: 1.1271185874938965\n",
      "Epoch: 209 Loss: 1.1270302534103394\n",
      "Epoch: 210 Loss: 1.1269336938858032\n",
      "Epoch: 211 Loss: 1.1268467903137207\n",
      "Epoch: 212 Loss: 1.1267536878585815\n",
      "Epoch: 213 Loss: 1.1266565322875977\n",
      "Epoch: 214 Loss: 1.1265685558319092\n",
      "Epoch: 215 Loss: 1.1264716386795044\n",
      "Epoch: 216 Loss: 1.1263726949691772\n",
      "Epoch: 217 Loss: 1.1262766122817993\n",
      "Epoch: 218 Loss: 1.1261793375015259\n",
      "Epoch: 219 Loss: 1.1260801553726196\n",
      "Epoch: 220 Loss: 1.1259796619415283\n",
      "Epoch: 221 Loss: 1.1258851289749146\n",
      "Epoch: 222 Loss: 1.125781774520874\n",
      "Epoch: 223 Loss: 1.125683307647705\n",
      "Epoch: 224 Loss: 1.1255813837051392\n",
      "Epoch: 225 Loss: 1.1254748106002808\n",
      "Epoch: 226 Loss: 1.1253678798675537\n",
      "Epoch: 227 Loss: 1.1252601146697998\n",
      "Epoch: 228 Loss: 1.125157117843628\n",
      "Epoch: 229 Loss: 1.1250413656234741\n",
      "Epoch: 230 Loss: 1.1249322891235352\n",
      "Epoch: 231 Loss: 1.124815821647644\n",
      "Epoch: 232 Loss: 1.1247016191482544\n",
      "Epoch: 233 Loss: 1.1245849132537842\n",
      "Epoch: 234 Loss: 1.124468207359314\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 235 Loss: 1.1243518590927124\n",
      "Epoch: 236 Loss: 1.1242260932922363\n",
      "Epoch: 237 Loss: 1.1241044998168945\n",
      "Epoch: 238 Loss: 1.123986005783081\n",
      "Epoch: 239 Loss: 1.12385892868042\n",
      "Epoch: 240 Loss: 1.123731255531311\n",
      "Epoch: 241 Loss: 1.1236205101013184\n",
      "Epoch: 242 Loss: 1.123486876487732\n",
      "Epoch: 243 Loss: 1.1233718395233154\n",
      "Epoch: 244 Loss: 1.1232503652572632\n",
      "Epoch: 245 Loss: 1.1231154203414917\n",
      "Epoch: 246 Loss: 1.1229811906814575\n",
      "Epoch: 247 Loss: 1.1228554248809814\n",
      "Epoch: 248 Loss: 1.1227210760116577\n",
      "Epoch: 249 Loss: 1.1225866079330444\n",
      "Epoch: 250 Loss: 1.1224478483200073\n",
      "Epoch: 251 Loss: 1.1223106384277344\n",
      "Epoch: 252 Loss: 1.122177004814148\n",
      "Epoch: 253 Loss: 1.1220247745513916\n",
      "Epoch: 254 Loss: 1.1218791007995605\n",
      "Epoch: 255 Loss: 1.1217379570007324\n",
      "Epoch: 256 Loss: 1.1215908527374268\n",
      "Epoch: 257 Loss: 1.121443271636963\n",
      "Epoch: 258 Loss: 1.121286153793335\n",
      "Epoch: 259 Loss: 1.1211326122283936\n",
      "Epoch: 260 Loss: 1.1209757328033447\n",
      "Epoch: 261 Loss: 1.1208267211914062\n",
      "Epoch: 262 Loss: 1.1206504106521606\n",
      "Epoch: 263 Loss: 1.1204930543899536\n",
      "Epoch: 264 Loss: 1.1203217506408691\n",
      "Epoch: 265 Loss: 1.1201351881027222\n",
      "Epoch: 266 Loss: 1.119951605796814\n",
      "Epoch: 267 Loss: 1.119789719581604\n",
      "Epoch: 268 Loss: 1.11961829662323\n",
      "Epoch: 269 Loss: 1.1194344758987427\n",
      "Epoch: 270 Loss: 1.1192399263381958\n",
      "Epoch: 271 Loss: 1.1190485954284668\n",
      "Epoch: 272 Loss: 1.1188703775405884\n",
      "Epoch: 273 Loss: 1.118676781654358\n",
      "Epoch: 274 Loss: 1.1184593439102173\n",
      "Epoch: 275 Loss: 1.1182655096054077\n",
      "Epoch: 276 Loss: 1.11807382106781\n",
      "Epoch: 277 Loss: 1.1178631782531738\n",
      "Epoch: 278 Loss: 1.1176393032073975\n",
      "Epoch: 279 Loss: 1.117437481880188\n",
      "Epoch: 280 Loss: 1.1172257661819458\n",
      "Epoch: 281 Loss: 1.117003321647644\n",
      "Epoch: 282 Loss: 1.1167978048324585\n",
      "Epoch: 283 Loss: 1.116567611694336\n",
      "Epoch: 284 Loss: 1.1163352727890015\n",
      "Epoch: 285 Loss: 1.1160887479782104\n",
      "Epoch: 286 Loss: 1.1158515214920044\n",
      "Epoch: 287 Loss: 1.1156179904937744\n",
      "Epoch: 288 Loss: 1.1153650283813477\n",
      "Epoch: 289 Loss: 1.1151187419891357\n",
      "Epoch: 290 Loss: 1.1148667335510254\n",
      "Epoch: 291 Loss: 1.1145800352096558\n",
      "Epoch: 292 Loss: 1.11432945728302\n",
      "Epoch: 293 Loss: 1.1140772104263306\n",
      "Epoch: 294 Loss: 1.113797664642334\n",
      "Epoch: 295 Loss: 1.113513708114624\n",
      "Epoch: 296 Loss: 1.1132071018218994\n",
      "Epoch: 297 Loss: 1.1129456758499146\n",
      "Epoch: 298 Loss: 1.1126612424850464\n",
      "Epoch: 299 Loss: 1.112335443496704\n",
      "Epoch: 300 Loss: 1.1120567321777344\n",
      "Epoch: 301 Loss: 1.111741542816162\n",
      "Epoch: 302 Loss: 1.1114040613174438\n",
      "Epoch: 303 Loss: 1.1110570430755615\n",
      "Epoch: 304 Loss: 1.1107770204544067\n",
      "Epoch: 305 Loss: 1.1104018688201904\n",
      "Epoch: 306 Loss: 1.1101232767105103\n",
      "Epoch: 307 Loss: 1.1097376346588135\n",
      "Epoch: 308 Loss: 1.1094231605529785\n",
      "Epoch: 309 Loss: 1.1090389490127563\n",
      "Epoch: 310 Loss: 1.1086326837539673\n",
      "Epoch: 311 Loss: 1.1082741022109985\n",
      "Epoch: 312 Loss: 1.1080201864242554\n",
      "Epoch: 313 Loss: 1.1075631380081177\n",
      "Epoch: 314 Loss: 1.1071727275848389\n",
      "Epoch: 315 Loss: 1.1067951917648315\n",
      "Epoch: 316 Loss: 1.1064176559448242\n",
      "Epoch: 317 Loss: 1.1060806512832642\n",
      "Epoch: 318 Loss: 1.1057096719741821\n",
      "Epoch: 319 Loss: 1.1053379774093628\n",
      "Epoch: 320 Loss: 1.1049891710281372\n",
      "Epoch: 321 Loss: 1.1045098304748535\n",
      "Epoch: 322 Loss: 1.1042743921279907\n",
      "Epoch: 323 Loss: 1.1038472652435303\n",
      "Epoch: 324 Loss: 1.1032615900039673\n",
      "Epoch: 325 Loss: 1.1029300689697266\n",
      "Epoch: 326 Loss: 1.1025500297546387\n",
      "Epoch: 327 Loss: 1.102237582206726\n",
      "Epoch: 328 Loss: 1.101718783378601\n",
      "Epoch: 329 Loss: 1.1016079187393188\n",
      "Epoch: 330 Loss: 1.1011749505996704\n",
      "Epoch: 331 Loss: 1.1008780002593994\n",
      "Epoch: 332 Loss: 1.1005438566207886\n",
      "Epoch: 333 Loss: 1.0998064279556274\n",
      "Epoch: 334 Loss: 1.0991853475570679\n",
      "Epoch: 335 Loss: 1.0987621545791626\n",
      "Epoch: 336 Loss: 1.098401427268982\n",
      "Epoch: 337 Loss: 1.0979806184768677\n",
      "Epoch: 338 Loss: 1.097618579864502\n",
      "Epoch: 339 Loss: 1.097230076789856\n",
      "Epoch: 340 Loss: 1.096835732460022\n",
      "Epoch: 341 Loss: 1.096483826637268\n",
      "Epoch: 342 Loss: 1.0960729122161865\n",
      "Epoch: 343 Loss: 1.095646619796753\n",
      "Epoch: 344 Loss: 1.0952045917510986\n",
      "Epoch: 345 Loss: 1.0947699546813965\n",
      "Epoch: 346 Loss: 1.094376802444458\n",
      "Epoch: 347 Loss: 1.0938690900802612\n",
      "Epoch: 348 Loss: 1.0935477018356323\n",
      "Epoch: 349 Loss: 1.0931107997894287\n",
      "Epoch: 350 Loss: 1.0926800966262817\n",
      "Epoch: 351 Loss: 1.0921432971954346\n",
      "Epoch: 352 Loss: 1.0916142463684082\n",
      "Epoch: 353 Loss: 1.0913349390029907\n",
      "Epoch: 354 Loss: 1.0908869504928589\n",
      "Epoch: 355 Loss: 1.0903311967849731\n",
      "Epoch: 356 Loss: 1.0899087190628052\n",
      "Epoch: 357 Loss: 1.0895369052886963\n",
      "Epoch: 358 Loss: 1.0890287160873413\n",
      "Epoch: 359 Loss: 1.088529348373413\n",
      "Epoch: 360 Loss: 1.0880995988845825\n",
      "Epoch: 361 Loss: 1.087646245956421\n",
      "Epoch: 362 Loss: 1.0871639251708984\n",
      "Epoch: 363 Loss: 1.0866609811782837\n",
      "Epoch: 364 Loss: 1.0861918926239014\n",
      "Epoch: 365 Loss: 1.085727572441101\n",
      "Epoch: 366 Loss: 1.0852694511413574\n",
      "Epoch: 367 Loss: 1.0847870111465454\n",
      "Epoch: 368 Loss: 1.0842137336730957\n",
      "Epoch: 369 Loss: 1.0838044881820679\n",
      "Epoch: 370 Loss: 1.0832666158676147\n",
      "Epoch: 371 Loss: 1.0827584266662598\n",
      "Epoch: 372 Loss: 1.0823423862457275\n",
      "Epoch: 373 Loss: 1.0819532871246338\n",
      "Epoch: 374 Loss: 1.081296682357788\n",
      "Epoch: 375 Loss: 1.080867052078247\n",
      "Epoch: 376 Loss: 1.0804170370101929\n",
      "Epoch: 377 Loss: 1.0798572301864624\n",
      "Epoch: 378 Loss: 1.0794222354888916\n",
      "Epoch: 379 Loss: 1.0789073705673218\n",
      "Epoch: 380 Loss: 1.0781735181808472\n",
      "Epoch: 381 Loss: 1.0778888463974\n",
      "Epoch: 382 Loss: 1.077265977859497\n",
      "Epoch: 383 Loss: 1.076802134513855\n",
      "Epoch: 384 Loss: 1.0763359069824219\n",
      "Epoch: 385 Loss: 1.075731635093689\n",
      "Epoch: 386 Loss: 1.0752193927764893\n",
      "Epoch: 387 Loss: 1.0747054815292358\n",
      "Epoch: 388 Loss: 1.0740560293197632\n",
      "Epoch: 389 Loss: 1.0735050439834595\n",
      "Epoch: 390 Loss: 1.0729568004608154\n",
      "Epoch: 391 Loss: 1.072386622428894\n",
      "Epoch: 392 Loss: 1.0719144344329834\n",
      "Epoch: 393 Loss: 1.0711684226989746\n",
      "Epoch: 394 Loss: 1.0706907510757446\n",
      "Epoch: 395 Loss: 1.0703145265579224\n",
      "Epoch: 396 Loss: 1.0697311162948608\n",
      "Epoch: 397 Loss: 1.0691791772842407\n",
      "Epoch: 398 Loss: 1.068681240081787\n",
      "Epoch: 399 Loss: 1.0682350397109985\n",
      "Epoch: 400 Loss: 1.0675621032714844\n",
      "Epoch: 401 Loss: 1.0671271085739136\n",
      "Epoch: 402 Loss: 1.0662921667099\n",
      "Epoch: 403 Loss: 1.065736174583435\n",
      "Epoch: 404 Loss: 1.0654525756835938\n",
      "Epoch: 405 Loss: 1.0647729635238647\n",
      "Epoch: 406 Loss: 1.0644084215164185\n",
      "Epoch: 407 Loss: 1.0635782480239868\n",
      "Epoch: 408 Loss: 1.0631297826766968\n",
      "Epoch: 409 Loss: 1.0623339414596558\n",
      "Epoch: 410 Loss: 1.061985969543457\n",
      "Epoch: 411 Loss: 1.0613348484039307\n",
      "Epoch: 412 Loss: 1.060482382774353\n",
      "Epoch: 413 Loss: 1.060004472732544\n",
      "Epoch: 414 Loss: 1.059225082397461\n",
      "Epoch: 415 Loss: 1.0588304996490479\n",
      "Epoch: 416 Loss: 1.0580798387527466\n",
      "Epoch: 417 Loss: 1.0577386617660522\n",
      "Epoch: 418 Loss: 1.0570645332336426\n",
      "Epoch: 419 Loss: 1.0564700365066528\n",
      "Epoch: 420 Loss: 1.0557035207748413\n",
      "Epoch: 421 Loss: 1.0552617311477661\n",
      "Epoch: 422 Loss: 1.0544142723083496\n",
      "Epoch: 423 Loss: 1.0539783239364624\n",
      "Epoch: 424 Loss: 1.0533838272094727\n",
      "Epoch: 425 Loss: 1.0526970624923706\n",
      "Epoch: 426 Loss: 1.052010416984558\n",
      "Epoch: 427 Loss: 1.0512663125991821\n",
      "Epoch: 428 Loss: 1.0507055521011353\n",
      "Epoch: 429 Loss: 1.050052285194397\n",
      "Epoch: 430 Loss: 1.0496389865875244\n",
      "Epoch: 431 Loss: 1.0486656427383423\n",
      "Epoch: 432 Loss: 1.0483506917953491\n",
      "Epoch: 433 Loss: 1.0477838516235352\n",
      "Epoch: 434 Loss: 1.0466361045837402\n",
      "Epoch: 435 Loss: 1.0464274883270264\n",
      "Epoch: 436 Loss: 1.0456762313842773\n",
      "Epoch: 437 Loss: 1.0453243255615234\n",
      "Epoch: 438 Loss: 1.0443986654281616\n",
      "Epoch: 439 Loss: 1.0436006784439087\n",
      "Epoch: 440 Loss: 1.0433143377304077\n",
      "Epoch: 441 Loss: 1.0427567958831787\n",
      "Epoch: 442 Loss: 1.0419714450836182\n",
      "Epoch: 443 Loss: 1.040967345237732\n",
      "Epoch: 444 Loss: 1.040831208229065\n",
      "Epoch: 445 Loss: 1.0401309728622437\n",
      "Epoch: 446 Loss: 1.0396616458892822\n",
      "Epoch: 447 Loss: 1.0387145280838013\n",
      "Epoch: 448 Loss: 1.0377854108810425\n",
      "Epoch: 449 Loss: 1.0373036861419678\n",
      "Epoch: 450 Loss: 1.036655306816101\n",
      "Epoch: 451 Loss: 1.0358779430389404\n",
      "Epoch: 452 Loss: 1.0351766347885132\n",
      "Epoch: 453 Loss: 1.0345991849899292\n",
      "Epoch: 454 Loss: 1.0337433815002441\n",
      "Epoch: 455 Loss: 1.0333425998687744\n",
      "Epoch: 456 Loss: 1.0328092575073242\n",
      "Epoch: 457 Loss: 1.031876564025879\n",
      "Epoch: 458 Loss: 1.0312937498092651\n",
      "Epoch: 459 Loss: 1.0303008556365967\n",
      "Epoch: 460 Loss: 1.0301439762115479\n",
      "Epoch: 461 Loss: 1.02925705909729\n",
      "Epoch: 462 Loss: 1.0283920764923096\n",
      "Epoch: 463 Loss: 1.0278183221817017\n",
      "Epoch: 464 Loss: 1.0273000001907349\n",
      "Epoch: 465 Loss: 1.0265734195709229\n",
      "Epoch: 466 Loss: 1.026113510131836\n",
      "Epoch: 467 Loss: 1.0248644351959229\n",
      "Epoch: 468 Loss: 1.0245485305786133\n",
      "Epoch: 469 Loss: 1.0237386226654053\n",
      "Epoch: 470 Loss: 1.0230474472045898\n",
      "Epoch: 471 Loss: 1.02191960811615\n",
      "Epoch: 472 Loss: 1.0211703777313232\n",
      "Epoch: 473 Loss: 1.0206960439682007\n",
      "Epoch: 474 Loss: 1.0198324918746948\n",
      "Epoch: 475 Loss: 1.0189540386199951\n",
      "Epoch: 476 Loss: 1.0181984901428223\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 477 Loss: 1.0169470310211182\n",
      "Epoch: 478 Loss: 1.0164058208465576\n",
      "Epoch: 479 Loss: 1.0156841278076172\n",
      "Epoch: 480 Loss: 1.015245795249939\n",
      "Epoch: 481 Loss: 1.0140904188156128\n",
      "Epoch: 482 Loss: 1.0138357877731323\n",
      "Epoch: 483 Loss: 1.0129069089889526\n",
      "Epoch: 484 Loss: 1.0119367837905884\n",
      "Epoch: 485 Loss: 1.0111967325210571\n",
      "Epoch: 486 Loss: 1.0102176666259766\n",
      "Epoch: 487 Loss: 1.009842038154602\n",
      "Epoch: 488 Loss: 1.0088571310043335\n",
      "Epoch: 489 Loss: 1.0086297988891602\n",
      "Epoch: 490 Loss: 1.0071672201156616\n",
      "Epoch: 491 Loss: 1.0068907737731934\n",
      "Epoch: 492 Loss: 1.0057222843170166\n",
      "Epoch: 493 Loss: 1.0049949884414673\n",
      "Epoch: 494 Loss: 1.0045706033706665\n",
      "Epoch: 495 Loss: 1.0041097402572632\n",
      "Epoch: 496 Loss: 1.0033283233642578\n",
      "Epoch: 497 Loss: 1.0022178888320923\n",
      "Epoch: 498 Loss: 1.0014426708221436\n",
      "Epoch: 499 Loss: 1.0003271102905273\n",
      "Epoch: 500 Loss: 1.000455379486084\n",
      "Epoch: 501 Loss: 0.9996194243431091\n",
      "Epoch: 502 Loss: 0.9984729290008545\n",
      "Epoch: 503 Loss: 1.0009974241256714\n",
      "Epoch: 504 Loss: 0.9975752830505371\n",
      "Epoch: 505 Loss: 0.9993940591812134\n",
      "Epoch: 506 Loss: 0.9957968592643738\n",
      "Epoch: 507 Loss: 0.9969072937965393\n",
      "Epoch: 508 Loss: 0.9949268102645874\n",
      "Epoch: 509 Loss: 0.9951923489570618\n",
      "Epoch: 510 Loss: 0.9936690330505371\n",
      "Epoch: 511 Loss: 0.9931305646896362\n",
      "Epoch: 512 Loss: 0.9917634129524231\n",
      "Epoch: 513 Loss: 0.9910022616386414\n",
      "Epoch: 514 Loss: 0.990297794342041\n",
      "Epoch: 515 Loss: 0.9898818731307983\n",
      "Epoch: 516 Loss: 0.9883328080177307\n",
      "Epoch: 517 Loss: 0.9885256290435791\n",
      "Epoch: 518 Loss: 0.9876739978790283\n",
      "Epoch: 519 Loss: 0.9870556592941284\n",
      "Epoch: 520 Loss: 0.9857414960861206\n",
      "Epoch: 521 Loss: 0.985196590423584\n",
      "Epoch: 522 Loss: 0.984170138835907\n",
      "Epoch: 523 Loss: 0.983726978302002\n",
      "Epoch: 524 Loss: 0.9834291338920593\n",
      "Epoch: 525 Loss: 0.981545090675354\n",
      "Epoch: 526 Loss: 0.9820225834846497\n",
      "Epoch: 527 Loss: 0.9805842638015747\n",
      "Epoch: 528 Loss: 0.9792588353157043\n",
      "Epoch: 529 Loss: 0.9790155291557312\n",
      "Epoch: 530 Loss: 0.9780834913253784\n",
      "Epoch: 531 Loss: 0.9770175814628601\n",
      "Epoch: 532 Loss: 0.9758765697479248\n",
      "Epoch: 533 Loss: 0.976371169090271\n",
      "Epoch: 534 Loss: 0.9758875370025635\n",
      "Epoch: 535 Loss: 0.9735588431358337\n",
      "Epoch: 536 Loss: 0.974554181098938\n",
      "Epoch: 537 Loss: 0.972383975982666\n",
      "Epoch: 538 Loss: 0.9723476767539978\n",
      "Epoch: 539 Loss: 0.9708409905433655\n",
      "Epoch: 540 Loss: 0.9695843458175659\n",
      "Epoch: 541 Loss: 0.9714033603668213\n",
      "Epoch: 542 Loss: 0.9684605002403259\n",
      "Epoch: 543 Loss: 0.9673042893409729\n",
      "Epoch: 544 Loss: 0.9658445715904236\n",
      "Epoch: 545 Loss: 0.9636393785476685\n",
      "Epoch: 546 Loss: 0.9677884578704834\n",
      "Epoch: 547 Loss: 0.9644144773483276\n",
      "Epoch: 548 Loss: 0.9645833373069763\n",
      "Epoch: 549 Loss: 0.9619366526603699\n",
      "Epoch: 550 Loss: 0.9624378085136414\n",
      "Epoch: 551 Loss: 0.9618809819221497\n",
      "Epoch: 552 Loss: 0.9590139985084534\n",
      "Epoch: 553 Loss: 0.9598402976989746\n",
      "Epoch: 554 Loss: 0.9588460326194763\n",
      "Epoch: 555 Loss: 0.9574756622314453\n",
      "Epoch: 556 Loss: 0.9565178155899048\n",
      "Epoch: 557 Loss: 0.9562681913375854\n",
      "Epoch: 558 Loss: 0.9539746642112732\n",
      "Epoch: 559 Loss: 0.9545136094093323\n",
      "Epoch: 560 Loss: 0.9554665088653564\n",
      "Epoch: 561 Loss: 0.9514045119285583\n",
      "Epoch: 562 Loss: 0.949479341506958\n",
      "Epoch: 563 Loss: 0.9507191777229309\n",
      "Epoch: 564 Loss: 0.9500700831413269\n",
      "Epoch: 565 Loss: 0.9500521421432495\n",
      "Epoch: 566 Loss: 0.9492478370666504\n",
      "Epoch: 567 Loss: 0.9474289417266846\n",
      "Epoch: 568 Loss: 0.9450880885124207\n",
      "Epoch: 569 Loss: 0.944922685623169\n",
      "Epoch: 570 Loss: 0.945354163646698\n",
      "Epoch: 571 Loss: 0.9461016654968262\n",
      "Epoch: 572 Loss: 0.9424131512641907\n",
      "Epoch: 573 Loss: 0.941405713558197\n",
      "Epoch: 574 Loss: 0.9418626427650452\n",
      "Epoch: 575 Loss: 0.9396799206733704\n",
      "Epoch: 576 Loss: 0.9391783475875854\n",
      "Epoch: 577 Loss: 0.9392109513282776\n",
      "Epoch: 578 Loss: 0.9375864863395691\n",
      "Epoch: 579 Loss: 0.9370399117469788\n",
      "Epoch: 580 Loss: 0.937750518321991\n",
      "Epoch: 581 Loss: 0.9365839958190918\n",
      "Epoch: 582 Loss: 0.9335255026817322\n",
      "Epoch: 583 Loss: 0.9319561123847961\n",
      "Epoch: 584 Loss: 0.9346998333930969\n",
      "Epoch: 585 Loss: 0.9314708709716797\n",
      "Epoch: 586 Loss: 0.933095395565033\n",
      "Epoch: 587 Loss: 0.9283064007759094\n",
      "Epoch: 588 Loss: 0.9296017289161682\n",
      "Epoch: 589 Loss: 0.9292718768119812\n",
      "Epoch: 590 Loss: 0.9264757633209229\n",
      "Epoch: 591 Loss: 0.9263455271720886\n",
      "Epoch: 592 Loss: 0.9273872375488281\n",
      "Epoch: 593 Loss: 0.9238070845603943\n",
      "Epoch: 594 Loss: 0.922609269618988\n",
      "Epoch: 595 Loss: 0.923794150352478\n",
      "Epoch: 596 Loss: 0.9208877086639404\n",
      "Epoch: 597 Loss: 0.9213778972625732\n",
      "Epoch: 598 Loss: 0.920477569103241\n",
      "Epoch: 599 Loss: 0.9183984994888306\n",
      "Epoch: 600 Loss: 0.9171709418296814\n",
      "Epoch: 601 Loss: 0.9176185131072998\n",
      "Epoch: 602 Loss: 0.9160946011543274\n",
      "Epoch: 603 Loss: 0.915505588054657\n",
      "Epoch: 604 Loss: 0.9173715710639954\n",
      "Epoch: 605 Loss: 0.912653923034668\n",
      "Epoch: 606 Loss: 0.9120131731033325\n",
      "Epoch: 607 Loss: 0.9110011458396912\n",
      "Epoch: 608 Loss: 0.9091194868087769\n",
      "Epoch: 609 Loss: 0.9117856025695801\n",
      "Epoch: 610 Loss: 0.9067298769950867\n",
      "Epoch: 611 Loss: 0.9090607762336731\n",
      "Epoch: 612 Loss: 0.9054461121559143\n",
      "Epoch: 613 Loss: 0.9063969850540161\n",
      "Epoch: 614 Loss: 0.9042546153068542\n",
      "Epoch: 615 Loss: 0.9027125835418701\n",
      "Epoch: 616 Loss: 0.9035295844078064\n",
      "Epoch: 617 Loss: 0.9019988179206848\n",
      "Epoch: 618 Loss: 0.8999831676483154\n",
      "Epoch: 619 Loss: 0.8992998600006104\n",
      "Epoch: 620 Loss: 0.8998374342918396\n",
      "Epoch: 621 Loss: 0.8983879685401917\n",
      "Epoch: 622 Loss: 0.8993449211120605\n",
      "Epoch: 623 Loss: 0.8953387141227722\n",
      "Epoch: 624 Loss: 0.8950462341308594\n",
      "Epoch: 625 Loss: 0.8917537331581116\n",
      "Epoch: 626 Loss: 0.8946303129196167\n",
      "Epoch: 627 Loss: 0.8929725885391235\n",
      "Epoch: 628 Loss: 0.8916193246841431\n",
      "Epoch: 629 Loss: 0.8901514410972595\n",
      "Epoch: 630 Loss: 0.8892200589179993\n",
      "Epoch: 631 Loss: 0.8889347910881042\n",
      "Epoch: 632 Loss: 0.8877301216125488\n",
      "Epoch: 633 Loss: 0.8847933411598206\n",
      "Epoch: 634 Loss: 0.8852814435958862\n",
      "Epoch: 635 Loss: 0.8855648636817932\n",
      "Epoch: 636 Loss: 0.8823348879814148\n",
      "Epoch: 637 Loss: 0.8816485404968262\n",
      "Epoch: 638 Loss: 0.883183479309082\n",
      "Epoch: 639 Loss: 0.8803566694259644\n",
      "Epoch: 640 Loss: 0.876807689666748\n",
      "Epoch: 641 Loss: 0.8785204887390137\n",
      "Epoch: 642 Loss: 0.8789061903953552\n",
      "Epoch: 643 Loss: 0.8762271404266357\n",
      "Epoch: 644 Loss: 0.8749611973762512\n",
      "Epoch: 645 Loss: 0.8726414442062378\n",
      "Epoch: 646 Loss: 0.873367190361023\n",
      "Epoch: 647 Loss: 0.8713034987449646\n",
      "Epoch: 648 Loss: 0.870990514755249\n",
      "Epoch: 649 Loss: 0.8702021837234497\n",
      "Epoch: 650 Loss: 0.8672930002212524\n",
      "Epoch: 651 Loss: 0.8679010272026062\n",
      "Epoch: 652 Loss: 0.8675572872161865\n",
      "Epoch: 653 Loss: 0.8675309419631958\n",
      "Epoch: 654 Loss: 0.8633657097816467\n",
      "Epoch: 655 Loss: 0.8653574585914612\n",
      "Epoch: 656 Loss: 0.8627119660377502\n",
      "Epoch: 657 Loss: 0.8622024655342102\n",
      "Epoch: 658 Loss: 0.859891951084137\n",
      "Epoch: 659 Loss: 0.8577309250831604\n",
      "Epoch: 660 Loss: 0.8579493761062622\n",
      "Epoch: 661 Loss: 0.8582304120063782\n",
      "Epoch: 662 Loss: 0.8564459681510925\n",
      "Epoch: 663 Loss: 0.8572221994400024\n",
      "Epoch: 664 Loss: 0.8536855578422546\n",
      "Epoch: 665 Loss: 0.856471061706543\n",
      "Epoch: 666 Loss: 0.8508054614067078\n",
      "Epoch: 667 Loss: 0.8505745530128479\n",
      "Epoch: 668 Loss: 0.850252628326416\n",
      "Epoch: 669 Loss: 0.8491725325584412\n",
      "Epoch: 670 Loss: 0.8481160402297974\n",
      "Epoch: 671 Loss: 0.8485754728317261\n",
      "Epoch: 672 Loss: 0.8460758924484253\n",
      "Epoch: 673 Loss: 0.8485752940177917\n",
      "Epoch: 674 Loss: 0.8420103192329407\n",
      "Epoch: 675 Loss: 0.8437463045120239\n",
      "Epoch: 676 Loss: 0.8419939875602722\n",
      "Epoch: 677 Loss: 0.8422205448150635\n",
      "Epoch: 678 Loss: 0.838343620300293\n",
      "Epoch: 679 Loss: 0.8381603360176086\n",
      "Epoch: 680 Loss: 0.8377304077148438\n",
      "Epoch: 681 Loss: 0.8362454771995544\n",
      "Epoch: 682 Loss: 0.8344930410385132\n",
      "Epoch: 683 Loss: 0.8331457376480103\n",
      "Epoch: 684 Loss: 0.8332890868186951\n",
      "Epoch: 685 Loss: 0.8313658237457275\n",
      "Epoch: 686 Loss: 0.8342143893241882\n",
      "Epoch: 687 Loss: 0.8277317881584167\n",
      "Epoch: 688 Loss: 0.8303632736206055\n",
      "Epoch: 689 Loss: 0.8258471488952637\n",
      "Epoch: 690 Loss: 0.829867422580719\n",
      "Epoch: 691 Loss: 0.8254421353340149\n",
      "Epoch: 692 Loss: 0.8246580362319946\n",
      "Epoch: 693 Loss: 0.8232406377792358\n",
      "Epoch: 694 Loss: 0.8237690925598145\n",
      "Epoch: 695 Loss: 0.8197668194770813\n",
      "Epoch: 696 Loss: 0.8228768110275269\n",
      "Epoch: 697 Loss: 0.8188248872756958\n",
      "Epoch: 698 Loss: 0.8169060349464417\n",
      "Epoch: 699 Loss: 0.8170427680015564\n",
      "Epoch: 700 Loss: 0.8172638416290283\n",
      "Epoch: 701 Loss: 0.8145975470542908\n",
      "Epoch: 702 Loss: 0.8155110478401184\n",
      "Epoch: 703 Loss: 0.8131051063537598\n",
      "Epoch: 704 Loss: 0.8118221163749695\n",
      "Epoch: 705 Loss: 0.8127391338348389\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 706 Loss: 0.8089808821678162\n",
      "Epoch: 707 Loss: 0.8104071617126465\n",
      "Epoch: 708 Loss: 0.8067026138305664\n",
      "Epoch: 709 Loss: 0.8088684678077698\n",
      "Epoch: 710 Loss: 0.8047152757644653\n",
      "Epoch: 711 Loss: 0.8059985041618347\n",
      "Epoch: 712 Loss: 0.8031206727027893\n",
      "Epoch: 713 Loss: 0.8029946684837341\n",
      "Epoch: 714 Loss: 0.800683319568634\n",
      "Epoch: 715 Loss: 0.8016392588615417\n",
      "Epoch: 716 Loss: 0.7986423373222351\n",
      "Epoch: 717 Loss: 0.7988008856773376\n",
      "Epoch: 718 Loss: 0.7977583408355713\n",
      "Epoch: 719 Loss: 0.7947818040847778\n",
      "Epoch: 720 Loss: 0.7966616153717041\n",
      "Epoch: 721 Loss: 0.7928966283798218\n",
      "Epoch: 722 Loss: 0.7956968545913696\n",
      "Epoch: 723 Loss: 0.7919869422912598\n",
      "Epoch: 724 Loss: 0.7916623950004578\n",
      "Epoch: 725 Loss: 0.7900850772857666\n",
      "Epoch: 726 Loss: 0.7895097732543945\n",
      "Epoch: 727 Loss: 0.7883893847465515\n",
      "Epoch: 728 Loss: 0.7864912748336792\n",
      "Epoch: 729 Loss: 0.7869855165481567\n",
      "Epoch: 730 Loss: 0.7848015427589417\n",
      "Epoch: 731 Loss: 0.7843725085258484\n",
      "Epoch: 732 Loss: 0.7833723425865173\n",
      "Epoch: 733 Loss: 0.7817403078079224\n",
      "Epoch: 734 Loss: 0.7820266485214233\n",
      "Epoch: 735 Loss: 0.77904212474823\n",
      "Epoch: 736 Loss: 0.7793537378311157\n",
      "Epoch: 737 Loss: 0.7776186466217041\n",
      "Epoch: 738 Loss: 0.7768265604972839\n",
      "Epoch: 739 Loss: 0.7753992676734924\n",
      "Epoch: 740 Loss: 0.7770681977272034\n",
      "Epoch: 741 Loss: 0.773063063621521\n",
      "Epoch: 742 Loss: 0.7732106447219849\n",
      "Epoch: 743 Loss: 0.7710105180740356\n",
      "Epoch: 744 Loss: 0.7717345356941223\n",
      "Epoch: 745 Loss: 0.7689521312713623\n",
      "Epoch: 746 Loss: 0.7703216075897217\n",
      "Epoch: 747 Loss: 0.766780436038971\n",
      "Epoch: 748 Loss: 0.7671802639961243\n",
      "Epoch: 749 Loss: 0.7651764154434204\n",
      "Epoch: 750 Loss: 0.7663453817367554\n",
      "Epoch: 751 Loss: 0.7628892064094543\n",
      "Epoch: 752 Loss: 0.7636300921440125\n",
      "Epoch: 753 Loss: 0.761029064655304\n",
      "Epoch: 754 Loss: 0.7611857652664185\n",
      "Epoch: 755 Loss: 0.7587876915931702\n",
      "Epoch: 756 Loss: 0.7595551609992981\n",
      "Epoch: 757 Loss: 0.7566842436790466\n",
      "Epoch: 758 Loss: 0.7567640542984009\n",
      "Epoch: 759 Loss: 0.755727231502533\n",
      "Epoch: 760 Loss: 0.754163920879364\n",
      "Epoch: 761 Loss: 0.7545925974845886\n",
      "Epoch: 762 Loss: 0.7518190145492554\n",
      "Epoch: 763 Loss: 0.7527292370796204\n",
      "Epoch: 764 Loss: 0.7494574785232544\n",
      "Epoch: 765 Loss: 0.7501950263977051\n",
      "Epoch: 766 Loss: 0.7472171783447266\n",
      "Epoch: 767 Loss: 0.7477479577064514\n",
      "Epoch: 768 Loss: 0.7457408308982849\n",
      "Epoch: 769 Loss: 0.7448764443397522\n",
      "Epoch: 770 Loss: 0.7445037961006165\n",
      "Epoch: 771 Loss: 0.7437456250190735\n",
      "Epoch: 772 Loss: 0.7414736747741699\n",
      "Epoch: 773 Loss: 0.7415558695793152\n",
      "Epoch: 774 Loss: 0.7398374676704407\n",
      "Epoch: 775 Loss: 0.7405129075050354\n",
      "Epoch: 776 Loss: 0.7373217344284058\n",
      "Epoch: 777 Loss: 0.737241268157959\n",
      "Epoch: 778 Loss: 0.7354251146316528\n",
      "Epoch: 779 Loss: 0.7362883687019348\n",
      "Epoch: 780 Loss: 0.7333967685699463\n",
      "Epoch: 781 Loss: 0.7335584759712219\n",
      "Epoch: 782 Loss: 0.7312490344047546\n",
      "Epoch: 783 Loss: 0.7322400212287903\n",
      "Epoch: 784 Loss: 0.7292697429656982\n",
      "Epoch: 785 Loss: 0.7299661636352539\n",
      "Epoch: 786 Loss: 0.7271506786346436\n",
      "Epoch: 787 Loss: 0.7273697853088379\n",
      "Epoch: 788 Loss: 0.7251990437507629\n",
      "Epoch: 789 Loss: 0.7249505519866943\n",
      "Epoch: 790 Loss: 0.7230759859085083\n",
      "Epoch: 791 Loss: 0.7232457995414734\n",
      "Epoch: 792 Loss: 0.7207052707672119\n",
      "Epoch: 793 Loss: 0.7210170030593872\n",
      "Epoch: 794 Loss: 0.7189929485321045\n",
      "Epoch: 795 Loss: 0.7205104231834412\n",
      "Epoch: 796 Loss: 0.7167028188705444\n",
      "Epoch: 797 Loss: 0.7172025442123413\n",
      "Epoch: 798 Loss: 0.714631974697113\n",
      "Epoch: 799 Loss: 0.7154566645622253\n",
      "Epoch: 800 Loss: 0.7126117944717407\n",
      "Epoch: 801 Loss: 0.7131263017654419\n",
      "Epoch: 802 Loss: 0.7106145620346069\n",
      "Epoch: 803 Loss: 0.7103406190872192\n",
      "Epoch: 804 Loss: 0.7087877988815308\n",
      "Epoch: 805 Loss: 0.7089645862579346\n",
      "Epoch: 806 Loss: 0.7064817547798157\n",
      "Epoch: 807 Loss: 0.7067750096321106\n",
      "Epoch: 808 Loss: 0.7043188810348511\n",
      "Epoch: 809 Loss: 0.7048957943916321\n",
      "Epoch: 810 Loss: 0.7021744847297668\n",
      "Epoch: 811 Loss: 0.7029871940612793\n",
      "Epoch: 812 Loss: 0.6999280452728271\n",
      "Epoch: 813 Loss: 0.7010262608528137\n",
      "Epoch: 814 Loss: 0.6981289982795715\n",
      "Epoch: 815 Loss: 0.698557436466217\n",
      "Epoch: 816 Loss: 0.6957795023918152\n",
      "Epoch: 817 Loss: 0.6966675519943237\n",
      "Epoch: 818 Loss: 0.6938087344169617\n",
      "Epoch: 819 Loss: 0.6948986649513245\n",
      "Epoch: 820 Loss: 0.6919453740119934\n",
      "Epoch: 821 Loss: 0.6926137804985046\n",
      "Epoch: 822 Loss: 0.6895492076873779\n",
      "Epoch: 823 Loss: 0.6910618543624878\n",
      "Epoch: 824 Loss: 0.6874080300331116\n",
      "Epoch: 825 Loss: 0.6884058713912964\n",
      "Epoch: 826 Loss: 0.685306191444397\n",
      "Epoch: 827 Loss: 0.6864284873008728\n",
      "Epoch: 828 Loss: 0.6834753155708313\n",
      "Epoch: 829 Loss: 0.6840124726295471\n",
      "Epoch: 830 Loss: 0.6814344525337219\n",
      "Epoch: 831 Loss: 0.6813713312149048\n",
      "Epoch: 832 Loss: 0.6796747446060181\n",
      "Epoch: 833 Loss: 0.6792466044425964\n",
      "Epoch: 834 Loss: 0.6789682507514954\n",
      "Epoch: 835 Loss: 0.6762639284133911\n",
      "Epoch: 836 Loss: 0.6773759126663208\n",
      "Epoch: 837 Loss: 0.6747724413871765\n",
      "Epoch: 838 Loss: 0.67513108253479\n",
      "Epoch: 839 Loss: 0.6719778776168823\n",
      "Epoch: 840 Loss: 0.673020601272583\n",
      "Epoch: 841 Loss: 0.6699268817901611\n",
      "Epoch: 842 Loss: 0.6712299585342407\n",
      "Epoch: 843 Loss: 0.6682762503623962\n",
      "Epoch: 844 Loss: 0.6686745882034302\n",
      "Epoch: 845 Loss: 0.6659997701644897\n",
      "Epoch: 846 Loss: 0.6667563319206238\n",
      "Epoch: 847 Loss: 0.6643199920654297\n",
      "Epoch: 848 Loss: 0.6649546027183533\n",
      "Epoch: 849 Loss: 0.661919891834259\n",
      "Epoch: 850 Loss: 0.6629926562309265\n",
      "Epoch: 851 Loss: 0.6609184741973877\n",
      "Epoch: 852 Loss: 0.6609824895858765\n",
      "Epoch: 853 Loss: 0.6585676074028015\n",
      "Epoch: 854 Loss: 0.6588486433029175\n",
      "Epoch: 855 Loss: 0.656042754650116\n",
      "Epoch: 856 Loss: 0.6569038033485413\n",
      "Epoch: 857 Loss: 0.6541910171508789\n",
      "Epoch: 858 Loss: 0.6551501750946045\n",
      "Epoch: 859 Loss: 0.6534708738327026\n",
      "Epoch: 860 Loss: 0.6535186171531677\n",
      "Epoch: 861 Loss: 0.6507970690727234\n",
      "Epoch: 862 Loss: 0.6508631706237793\n",
      "Epoch: 863 Loss: 0.6498594284057617\n",
      "Epoch: 864 Loss: 0.6470307111740112\n",
      "Epoch: 865 Loss: 0.6476538181304932\n",
      "Epoch: 866 Loss: 0.645629346370697\n",
      "Epoch: 867 Loss: 0.6462995409965515\n",
      "Epoch: 868 Loss: 0.6451650857925415\n",
      "Epoch: 869 Loss: 0.642777681350708\n",
      "Epoch: 870 Loss: 0.6438441872596741\n",
      "Epoch: 871 Loss: 0.6404110193252563\n",
      "Epoch: 872 Loss: 0.6411638259887695\n",
      "Epoch: 873 Loss: 0.6389672160148621\n",
      "Epoch: 874 Loss: 0.6404957175254822\n",
      "Epoch: 875 Loss: 0.6367149353027344\n",
      "Epoch: 876 Loss: 0.6373319029808044\n",
      "Epoch: 877 Loss: 0.6363293528556824\n",
      "Epoch: 878 Loss: 0.6339048743247986\n",
      "Epoch: 879 Loss: 0.6354175209999084\n",
      "Epoch: 880 Loss: 0.6324726343154907\n",
      "Epoch: 881 Loss: 0.6333777904510498\n",
      "Epoch: 882 Loss: 0.6309667825698853\n",
      "Epoch: 883 Loss: 0.6302268505096436\n",
      "Epoch: 884 Loss: 0.6302998661994934\n",
      "Epoch: 885 Loss: 0.6278404593467712\n",
      "Epoch: 886 Loss: 0.6281712651252747\n",
      "Epoch: 887 Loss: 0.6265732049942017\n",
      "Epoch: 888 Loss: 0.6257813572883606\n",
      "Epoch: 889 Loss: 0.6251347064971924\n",
      "Epoch: 890 Loss: 0.6230015158653259\n",
      "Epoch: 891 Loss: 0.6229403018951416\n",
      "Epoch: 892 Loss: 0.6210710406303406\n",
      "Epoch: 893 Loss: 0.6214061975479126\n",
      "Epoch: 894 Loss: 0.6197952032089233\n",
      "Epoch: 895 Loss: 0.6188988089561462\n",
      "Epoch: 896 Loss: 0.6187925934791565\n",
      "Epoch: 897 Loss: 0.6163330078125\n",
      "Epoch: 898 Loss: 0.6168805956840515\n",
      "Epoch: 899 Loss: 0.6153262257575989\n",
      "Epoch: 900 Loss: 0.6144579648971558\n",
      "Epoch: 901 Loss: 0.6127952337265015\n",
      "Epoch: 902 Loss: 0.6125138998031616\n",
      "Epoch: 903 Loss: 0.6129283308982849\n",
      "Epoch: 904 Loss: 0.6121072769165039\n",
      "Epoch: 905 Loss: 0.610389232635498\n",
      "Epoch: 906 Loss: 0.6087923049926758\n",
      "Epoch: 907 Loss: 0.6083400249481201\n",
      "Epoch: 908 Loss: 0.6067049503326416\n",
      "Epoch: 909 Loss: 0.6068841814994812\n",
      "Epoch: 910 Loss: 0.605220377445221\n",
      "Epoch: 911 Loss: 0.6047946810722351\n",
      "Epoch: 912 Loss: 0.6038311123847961\n",
      "Epoch: 913 Loss: 0.6029857993125916\n",
      "Epoch: 914 Loss: 0.6022706627845764\n",
      "Epoch: 915 Loss: 0.6002015471458435\n",
      "Epoch: 916 Loss: 0.600945234298706\n",
      "Epoch: 917 Loss: 0.598863422870636\n",
      "Epoch: 918 Loss: 0.5991635918617249\n",
      "Epoch: 919 Loss: 0.5969710350036621\n",
      "Epoch: 920 Loss: 0.5968672633171082\n",
      "Epoch: 921 Loss: 0.5955466032028198\n",
      "Epoch: 922 Loss: 0.5948034524917603\n",
      "Epoch: 923 Loss: 0.5940461754798889\n",
      "Epoch: 924 Loss: 0.5930595397949219\n",
      "Epoch: 925 Loss: 0.592711329460144\n",
      "Epoch: 926 Loss: 0.5910497903823853\n",
      "Epoch: 927 Loss: 0.5905392169952393\n",
      "Epoch: 928 Loss: 0.5893527865409851\n",
      "Epoch: 929 Loss: 0.5885478854179382\n",
      "Epoch: 930 Loss: 0.5879614353179932\n",
      "Epoch: 931 Loss: 0.5869704484939575\n",
      "Epoch: 932 Loss: 0.5859780311584473\n",
      "Epoch: 933 Loss: 0.5847480893135071\n",
      "Epoch: 934 Loss: 0.584206759929657\n",
      "Epoch: 935 Loss: 0.5831136107444763\n",
      "Epoch: 936 Loss: 0.5828554630279541\n",
      "Epoch: 937 Loss: 0.5815044045448303\n",
      "Epoch: 938 Loss: 0.5808978080749512\n",
      "Epoch: 939 Loss: 0.5800791382789612\n",
      "Epoch: 940 Loss: 0.579343855381012\n",
      "Epoch: 941 Loss: 0.5780560374259949\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 942 Loss: 0.5768752694129944\n",
      "Epoch: 943 Loss: 0.5766890645027161\n",
      "Epoch: 944 Loss: 0.574909508228302\n",
      "Epoch: 945 Loss: 0.5747119784355164\n",
      "Epoch: 946 Loss: 0.5732300281524658\n",
      "Epoch: 947 Loss: 0.5732151865959167\n",
      "Epoch: 948 Loss: 0.5713380575180054\n",
      "Epoch: 949 Loss: 0.5713919401168823\n",
      "Epoch: 950 Loss: 0.5702218413352966\n",
      "Epoch: 951 Loss: 0.5690693259239197\n",
      "Epoch: 952 Loss: 0.5682777166366577\n",
      "Epoch: 953 Loss: 0.5681558847427368\n",
      "Epoch: 954 Loss: 0.5673319697380066\n",
      "Epoch: 955 Loss: 0.5662036538124084\n",
      "Epoch: 956 Loss: 0.5652552843093872\n",
      "Epoch: 957 Loss: 0.5641111135482788\n",
      "Epoch: 958 Loss: 0.5636101961135864\n",
      "Epoch: 959 Loss: 0.5629693865776062\n",
      "Epoch: 960 Loss: 0.5614902377128601\n",
      "Epoch: 961 Loss: 0.5607122182846069\n",
      "Epoch: 962 Loss: 0.5602133870124817\n",
      "Epoch: 963 Loss: 0.559083878993988\n",
      "Epoch: 964 Loss: 0.558386504650116\n",
      "Epoch: 965 Loss: 0.557803213596344\n",
      "Epoch: 966 Loss: 0.5569421648979187\n",
      "Epoch: 967 Loss: 0.5554912090301514\n",
      "Epoch: 968 Loss: 0.5554984211921692\n",
      "Epoch: 969 Loss: 0.5535005331039429\n",
      "Epoch: 970 Loss: 0.5538679957389832\n",
      "Epoch: 971 Loss: 0.55170077085495\n",
      "Epoch: 972 Loss: 0.5519660711288452\n",
      "Epoch: 973 Loss: 0.5509911775588989\n",
      "Epoch: 974 Loss: 0.5505711436271667\n",
      "Epoch: 975 Loss: 0.5496652126312256\n",
      "Epoch: 976 Loss: 0.5485290288925171\n",
      "Epoch: 977 Loss: 0.5478196144104004\n",
      "Epoch: 978 Loss: 0.5476526618003845\n",
      "Epoch: 979 Loss: 0.5466657280921936\n",
      "Epoch: 980 Loss: 0.5455791354179382\n",
      "Epoch: 981 Loss: 0.5449042916297913\n",
      "Epoch: 982 Loss: 0.5443891882896423\n",
      "Epoch: 983 Loss: 0.5434889793395996\n",
      "Epoch: 984 Loss: 0.5422112941741943\n",
      "Epoch: 985 Loss: 0.5414961576461792\n",
      "Epoch: 986 Loss: 0.5411043167114258\n",
      "Epoch: 987 Loss: 0.5396486520767212\n",
      "Epoch: 988 Loss: 0.5393855571746826\n",
      "Epoch: 989 Loss: 0.5381571054458618\n",
      "Epoch: 990 Loss: 0.5378226041793823\n",
      "Epoch: 991 Loss: 0.5367502570152283\n",
      "Epoch: 992 Loss: 0.5358860492706299\n",
      "Epoch: 993 Loss: 0.5350938439369202\n",
      "Epoch: 994 Loss: 0.534281849861145\n",
      "Epoch: 995 Loss: 0.533771812915802\n",
      "Epoch: 996 Loss: 0.5328412652015686\n",
      "Epoch: 997 Loss: 0.5322744846343994\n",
      "Epoch: 998 Loss: 0.5317434072494507\n",
      "Epoch: 999 Loss: 0.5308892726898193\n",
      "Epoch: 1000 Loss: 0.529608428478241\n",
      "Epoch: 1001 Loss: 0.52908855676651\n",
      "Epoch: 1002 Loss: 0.5280444622039795\n",
      "Epoch: 1003 Loss: 0.527510941028595\n",
      "Epoch: 1004 Loss: 0.5269861221313477\n",
      "Epoch: 1005 Loss: 0.5259512066841125\n",
      "Epoch: 1006 Loss: 0.5255213379859924\n",
      "Epoch: 1007 Loss: 0.5247889161109924\n",
      "Epoch: 1008 Loss: 0.522989809513092\n",
      "Epoch: 1009 Loss: 0.523266077041626\n",
      "Epoch: 1010 Loss: 0.5223901271820068\n",
      "Epoch: 1011 Loss: 0.5214393138885498\n",
      "Epoch: 1012 Loss: 0.5208967924118042\n",
      "Epoch: 1013 Loss: 0.5202175378799438\n",
      "Epoch: 1014 Loss: 0.5188876986503601\n",
      "Epoch: 1015 Loss: 0.5184786915779114\n",
      "Epoch: 1016 Loss: 0.5174180269241333\n",
      "Epoch: 1017 Loss: 0.5174757838249207\n",
      "Epoch: 1018 Loss: 0.5159308314323425\n",
      "Epoch: 1019 Loss: 0.515531599521637\n",
      "Epoch: 1020 Loss: 0.514779269695282\n",
      "Epoch: 1021 Loss: 0.5142658352851868\n",
      "Epoch: 1022 Loss: 0.5137369632720947\n",
      "Epoch: 1023 Loss: 0.5119512677192688\n",
      "Epoch: 1024 Loss: 0.5121092200279236\n",
      "Epoch: 1025 Loss: 0.5109959840774536\n",
      "Epoch: 1026 Loss: 0.510794997215271\n",
      "Epoch: 1027 Loss: 0.5088139176368713\n",
      "Epoch: 1028 Loss: 0.5098432302474976\n",
      "Epoch: 1029 Loss: 0.5085112452507019\n",
      "Epoch: 1030 Loss: 0.5076423287391663\n",
      "Epoch: 1031 Loss: 0.5064022541046143\n",
      "Epoch: 1032 Loss: 0.5059863328933716\n",
      "Epoch: 1033 Loss: 0.505666196346283\n",
      "Epoch: 1034 Loss: 0.5044045448303223\n",
      "Epoch: 1035 Loss: 0.504051148891449\n",
      "Epoch: 1036 Loss: 0.5033279061317444\n",
      "Epoch: 1037 Loss: 0.5019913911819458\n",
      "Epoch: 1038 Loss: 0.5016024708747864\n",
      "Epoch: 1039 Loss: 0.500859797000885\n",
      "Epoch: 1040 Loss: 0.5007660388946533\n",
      "Epoch: 1041 Loss: 0.49909141659736633\n",
      "Epoch: 1042 Loss: 0.49941593408584595\n",
      "Epoch: 1043 Loss: 0.49817684292793274\n",
      "Epoch: 1044 Loss: 0.49792322516441345\n",
      "Epoch: 1045 Loss: 0.49642467498779297\n",
      "Epoch: 1046 Loss: 0.4964858889579773\n",
      "Epoch: 1047 Loss: 0.4948873519897461\n",
      "Epoch: 1048 Loss: 0.4947871267795563\n",
      "Epoch: 1049 Loss: 0.4937756359577179\n",
      "Epoch: 1050 Loss: 0.49343812465667725\n",
      "Epoch: 1051 Loss: 0.4917762875556946\n",
      "Epoch: 1052 Loss: 0.49165481328964233\n",
      "Epoch: 1053 Loss: 0.4907532036304474\n",
      "Epoch: 1054 Loss: 0.4896422326564789\n",
      "Epoch: 1055 Loss: 0.4897562861442566\n",
      "Epoch: 1056 Loss: 0.48872488737106323\n",
      "Epoch: 1057 Loss: 0.48786768317222595\n",
      "Epoch: 1058 Loss: 0.48711714148521423\n",
      "Epoch: 1059 Loss: 0.485870361328125\n",
      "Epoch: 1060 Loss: 0.4857895076274872\n",
      "Epoch: 1061 Loss: 0.48495203256607056\n",
      "Epoch: 1062 Loss: 0.48390457034111023\n",
      "Epoch: 1063 Loss: 0.4832385182380676\n",
      "Epoch: 1064 Loss: 0.48320266604423523\n",
      "Epoch: 1065 Loss: 0.4822356700897217\n",
      "Epoch: 1066 Loss: 0.4812513589859009\n",
      "Epoch: 1067 Loss: 0.4816230535507202\n",
      "Epoch: 1068 Loss: 0.4801292419433594\n",
      "Epoch: 1069 Loss: 0.4788109362125397\n",
      "Epoch: 1070 Loss: 0.4784599840641022\n",
      "Epoch: 1071 Loss: 0.4783220589160919\n",
      "Epoch: 1072 Loss: 0.4771769940853119\n",
      "Epoch: 1073 Loss: 0.4763508141040802\n",
      "Epoch: 1074 Loss: 0.47607505321502686\n",
      "Epoch: 1075 Loss: 0.475661039352417\n",
      "Epoch: 1076 Loss: 0.47453567385673523\n",
      "Epoch: 1077 Loss: 0.47330376505851746\n",
      "Epoch: 1078 Loss: 0.4732852578163147\n",
      "Epoch: 1079 Loss: 0.4723914563655853\n",
      "Epoch: 1080 Loss: 0.4719860255718231\n",
      "Epoch: 1081 Loss: 0.4707224667072296\n",
      "Epoch: 1082 Loss: 0.4707903563976288\n",
      "Epoch: 1083 Loss: 0.4697037935256958\n",
      "Epoch: 1084 Loss: 0.4692517817020416\n",
      "Epoch: 1085 Loss: 0.4684649705886841\n",
      "Epoch: 1086 Loss: 0.4675559401512146\n",
      "Epoch: 1087 Loss: 0.4672565162181854\n",
      "Epoch: 1088 Loss: 0.46646997332572937\n",
      "Epoch: 1089 Loss: 0.46573159098625183\n",
      "Epoch: 1090 Loss: 0.46484169363975525\n",
      "Epoch: 1091 Loss: 0.4640659987926483\n",
      "Epoch: 1092 Loss: 0.4637468159198761\n",
      "Epoch: 1093 Loss: 0.4632200300693512\n",
      "Epoch: 1094 Loss: 0.46197035908699036\n",
      "Epoch: 1095 Loss: 0.4621478021144867\n",
      "Epoch: 1096 Loss: 0.4605018198490143\n",
      "Epoch: 1097 Loss: 0.4608108103275299\n",
      "Epoch: 1098 Loss: 0.4595291316509247\n",
      "Epoch: 1099 Loss: 0.45952606201171875\n",
      "Epoch: 1100 Loss: 0.458734929561615\n",
      "Epoch: 1101 Loss: 0.4578816294670105\n",
      "Epoch: 1102 Loss: 0.45810478925704956\n",
      "Epoch: 1103 Loss: 0.45671966671943665\n",
      "Epoch: 1104 Loss: 0.45586976408958435\n",
      "Epoch: 1105 Loss: 0.45593467354774475\n",
      "Epoch: 1106 Loss: 0.4546439051628113\n",
      "Epoch: 1107 Loss: 0.45382046699523926\n",
      "Epoch: 1108 Loss: 0.4534778892993927\n",
      "Epoch: 1109 Loss: 0.452907532453537\n",
      "Epoch: 1110 Loss: 0.45204368233680725\n",
      "Epoch: 1111 Loss: 0.4517810046672821\n",
      "Epoch: 1112 Loss: 0.45113885402679443\n",
      "Epoch: 1113 Loss: 0.4504222571849823\n",
      "Epoch: 1114 Loss: 0.4500269889831543\n",
      "Epoch: 1115 Loss: 0.44915348291397095\n",
      "Epoch: 1116 Loss: 0.44822776317596436\n",
      "Epoch: 1117 Loss: 0.44755619764328003\n",
      "Epoch: 1118 Loss: 0.446982204914093\n",
      "Epoch: 1119 Loss: 0.44645169377326965\n",
      "Epoch: 1120 Loss: 0.4455976188182831\n",
      "Epoch: 1121 Loss: 0.4451439082622528\n",
      "Epoch: 1122 Loss: 0.44454315304756165\n",
      "Epoch: 1123 Loss: 0.4441254436969757\n",
      "Epoch: 1124 Loss: 0.44322067499160767\n",
      "Epoch: 1125 Loss: 0.4421113431453705\n",
      "Epoch: 1126 Loss: 0.442178875207901\n",
      "Epoch: 1127 Loss: 0.4412628412246704\n",
      "Epoch: 1128 Loss: 0.44067785143852234\n",
      "Epoch: 1129 Loss: 0.43997129797935486\n",
      "Epoch: 1130 Loss: 0.4395972490310669\n",
      "Epoch: 1131 Loss: 0.4388653635978699\n",
      "Epoch: 1132 Loss: 0.43853750824928284\n",
      "Epoch: 1133 Loss: 0.43723177909851074\n",
      "Epoch: 1134 Loss: 0.437520295381546\n",
      "Epoch: 1135 Loss: 0.43656954169273376\n",
      "Epoch: 1136 Loss: 0.43574684858322144\n",
      "Epoch: 1137 Loss: 0.4350665509700775\n",
      "Epoch: 1138 Loss: 0.43488508462905884\n",
      "Epoch: 1139 Loss: 0.43350720405578613\n",
      "Epoch: 1140 Loss: 0.43323981761932373\n",
      "Epoch: 1141 Loss: 0.43263930082321167\n",
      "Epoch: 1142 Loss: 0.431672602891922\n",
      "Epoch: 1143 Loss: 0.4314834773540497\n",
      "Epoch: 1144 Loss: 0.4309360980987549\n",
      "Epoch: 1145 Loss: 0.42995843291282654\n",
      "Epoch: 1146 Loss: 0.4298768937587738\n",
      "Epoch: 1147 Loss: 0.4286057949066162\n",
      "Epoch: 1148 Loss: 0.4283048212528229\n",
      "Epoch: 1149 Loss: 0.42769956588745117\n",
      "Epoch: 1150 Loss: 0.4267067611217499\n",
      "Epoch: 1151 Loss: 0.42676904797554016\n",
      "Epoch: 1152 Loss: 0.42585399746894836\n",
      "Epoch: 1153 Loss: 0.4248896837234497\n",
      "Epoch: 1154 Loss: 0.4247940480709076\n",
      "Epoch: 1155 Loss: 0.4240984320640564\n",
      "Epoch: 1156 Loss: 0.4234631359577179\n",
      "Epoch: 1157 Loss: 0.423317015171051\n",
      "Epoch: 1158 Loss: 0.4220590889453888\n",
      "Epoch: 1159 Loss: 0.42194864153862\n",
      "Epoch: 1160 Loss: 0.42089641094207764\n",
      "Epoch: 1161 Loss: 0.42036107182502747\n",
      "Epoch: 1162 Loss: 0.4198724925518036\n",
      "Epoch: 1163 Loss: 0.41914525628089905\n",
      "Epoch: 1164 Loss: 0.4184037744998932\n",
      "Epoch: 1165 Loss: 0.417906790971756\n",
      "Epoch: 1166 Loss: 0.41749903559684753\n",
      "Epoch: 1167 Loss: 0.4168854355812073\n",
      "Epoch: 1168 Loss: 0.41618549823760986\n",
      "Epoch: 1169 Loss: 0.41523483395576477\n",
      "Epoch: 1170 Loss: 0.41486647725105286\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1171 Loss: 0.41417765617370605\n",
      "Epoch: 1172 Loss: 0.4141392111778259\n",
      "Epoch: 1173 Loss: 0.4130527079105377\n",
      "Epoch: 1174 Loss: 0.4125176966190338\n",
      "Epoch: 1175 Loss: 0.41161713004112244\n",
      "Epoch: 1176 Loss: 0.4111669659614563\n",
      "Epoch: 1177 Loss: 0.41089871525764465\n",
      "Epoch: 1178 Loss: 0.4098060727119446\n",
      "Epoch: 1179 Loss: 0.4094087779521942\n",
      "Epoch: 1180 Loss: 0.4089564383029938\n",
      "Epoch: 1181 Loss: 0.40804198384284973\n",
      "Epoch: 1182 Loss: 0.40796804428100586\n",
      "Epoch: 1183 Loss: 0.40726423263549805\n",
      "Epoch: 1184 Loss: 0.4063868820667267\n",
      "Epoch: 1185 Loss: 0.4060220718383789\n",
      "Epoch: 1186 Loss: 0.4052112400531769\n",
      "Epoch: 1187 Loss: 0.4045411944389343\n",
      "Epoch: 1188 Loss: 0.4041464626789093\n",
      "Epoch: 1189 Loss: 0.403473824262619\n",
      "Epoch: 1190 Loss: 0.4029342234134674\n",
      "Epoch: 1191 Loss: 0.40284866094589233\n",
      "Epoch: 1192 Loss: 0.4018949866294861\n",
      "Epoch: 1193 Loss: 0.40180134773254395\n",
      "Epoch: 1194 Loss: 0.4005822539329529\n",
      "Epoch: 1195 Loss: 0.40030086040496826\n",
      "Epoch: 1196 Loss: 0.3991815745830536\n",
      "Epoch: 1197 Loss: 0.3986361026763916\n",
      "Epoch: 1198 Loss: 0.3979930877685547\n",
      "Epoch: 1199 Loss: 0.3976297080516815\n",
      "Epoch: 1200 Loss: 0.397481232881546\n",
      "Epoch: 1201 Loss: 0.3964258134365082\n",
      "Epoch: 1202 Loss: 0.3955300450325012\n",
      "Epoch: 1203 Loss: 0.3951661288738251\n",
      "Epoch: 1204 Loss: 0.39473846554756165\n",
      "Epoch: 1205 Loss: 0.39429038763046265\n",
      "Epoch: 1206 Loss: 0.3940091133117676\n",
      "Epoch: 1207 Loss: 0.39288318157196045\n",
      "Epoch: 1208 Loss: 0.3924161195755005\n",
      "Epoch: 1209 Loss: 0.3919912278652191\n",
      "Epoch: 1210 Loss: 0.3910887837409973\n",
      "Epoch: 1211 Loss: 0.39068976044654846\n",
      "Epoch: 1212 Loss: 0.39058443903923035\n",
      "Epoch: 1213 Loss: 0.389156311750412\n",
      "Epoch: 1214 Loss: 0.3893011510372162\n",
      "Epoch: 1215 Loss: 0.3885893225669861\n",
      "Epoch: 1216 Loss: 0.3880869448184967\n",
      "Epoch: 1217 Loss: 0.3873539865016937\n",
      "Epoch: 1218 Loss: 0.3864673674106598\n",
      "Epoch: 1219 Loss: 0.3859502673149109\n",
      "Epoch: 1220 Loss: 0.3858347535133362\n",
      "Epoch: 1221 Loss: 0.3846863806247711\n",
      "Epoch: 1222 Loss: 0.3847123682498932\n",
      "Epoch: 1223 Loss: 0.38370785117149353\n",
      "Epoch: 1224 Loss: 0.3830403983592987\n",
      "Epoch: 1225 Loss: 0.3824974298477173\n",
      "Epoch: 1226 Loss: 0.38252708315849304\n",
      "Epoch: 1227 Loss: 0.38143038749694824\n",
      "Epoch: 1228 Loss: 0.38148224353790283\n",
      "Epoch: 1229 Loss: 0.3807467520236969\n",
      "Epoch: 1230 Loss: 0.3801737129688263\n",
      "Epoch: 1231 Loss: 0.37961506843566895\n",
      "Epoch: 1232 Loss: 0.37880319356918335\n",
      "Epoch: 1233 Loss: 0.37848755717277527\n",
      "Epoch: 1234 Loss: 0.37805166840553284\n",
      "Epoch: 1235 Loss: 0.37750568985939026\n",
      "Epoch: 1236 Loss: 0.37701284885406494\n",
      "Epoch: 1237 Loss: 0.3763194978237152\n",
      "Epoch: 1238 Loss: 0.3761923611164093\n",
      "Epoch: 1239 Loss: 0.3754396438598633\n",
      "Epoch: 1240 Loss: 0.37468400597572327\n",
      "Epoch: 1241 Loss: 0.37430626153945923\n",
      "Epoch: 1242 Loss: 0.37394875288009644\n",
      "Epoch: 1243 Loss: 0.3733367621898651\n",
      "Epoch: 1244 Loss: 0.37256819009780884\n",
      "Epoch: 1245 Loss: 0.3725138008594513\n",
      "Epoch: 1246 Loss: 0.3718867003917694\n",
      "Epoch: 1247 Loss: 0.3713168203830719\n",
      "Epoch: 1248 Loss: 0.37126654386520386\n",
      "Epoch: 1249 Loss: 0.3703015446662903\n",
      "Epoch: 1250 Loss: 0.36934220790863037\n",
      "Epoch: 1251 Loss: 0.36902302503585815\n",
      "Epoch: 1252 Loss: 0.36816170811653137\n",
      "Epoch: 1253 Loss: 0.3681943416595459\n",
      "Epoch: 1254 Loss: 0.3675509989261627\n",
      "Epoch: 1255 Loss: 0.36659732460975647\n",
      "Epoch: 1256 Loss: 0.3664311170578003\n",
      "Epoch: 1257 Loss: 0.36601850390434265\n",
      "Epoch: 1258 Loss: 0.36494162678718567\n",
      "Epoch: 1259 Loss: 0.3649161159992218\n",
      "Epoch: 1260 Loss: 0.3650045394897461\n",
      "Epoch: 1261 Loss: 0.3637581467628479\n",
      "Epoch: 1262 Loss: 0.3635181486606598\n",
      "Epoch: 1263 Loss: 0.36300063133239746\n",
      "Epoch: 1264 Loss: 0.3625028133392334\n",
      "Epoch: 1265 Loss: 0.3620234727859497\n",
      "Epoch: 1266 Loss: 0.3615584075450897\n",
      "Epoch: 1267 Loss: 0.3611829876899719\n",
      "Epoch: 1268 Loss: 0.3604060113430023\n",
      "Epoch: 1269 Loss: 0.35965728759765625\n",
      "Epoch: 1270 Loss: 0.35972872376441956\n",
      "Epoch: 1271 Loss: 0.3585251271724701\n",
      "Epoch: 1272 Loss: 0.35838034749031067\n",
      "Epoch: 1273 Loss: 0.3579622507095337\n",
      "Epoch: 1274 Loss: 0.35770779848098755\n",
      "Epoch: 1275 Loss: 0.3571753203868866\n",
      "Epoch: 1276 Loss: 0.35626548528671265\n",
      "Epoch: 1277 Loss: 0.35574012994766235\n",
      "Epoch: 1278 Loss: 0.35544344782829285\n",
      "Epoch: 1279 Loss: 0.3551292419433594\n",
      "Epoch: 1280 Loss: 0.35433652997016907\n",
      "Epoch: 1281 Loss: 0.35349419713020325\n",
      "Epoch: 1282 Loss: 0.3532772958278656\n",
      "Epoch: 1283 Loss: 0.3531184196472168\n",
      "Epoch: 1284 Loss: 0.35271283984184265\n",
      "Epoch: 1285 Loss: 0.35188159346580505\n",
      "Epoch: 1286 Loss: 0.35091015696525574\n",
      "Epoch: 1287 Loss: 0.35074150562286377\n",
      "Epoch: 1288 Loss: 0.3505858778953552\n",
      "Epoch: 1289 Loss: 0.35010620951652527\n",
      "Epoch: 1290 Loss: 0.349754273891449\n",
      "Epoch: 1291 Loss: 0.3490108549594879\n",
      "Epoch: 1292 Loss: 0.34877482056617737\n",
      "Epoch: 1293 Loss: 0.34769904613494873\n",
      "Epoch: 1294 Loss: 0.347084641456604\n",
      "Epoch: 1295 Loss: 0.3470446467399597\n",
      "Epoch: 1296 Loss: 0.34618014097213745\n",
      "Epoch: 1297 Loss: 0.3464536964893341\n",
      "Epoch: 1298 Loss: 0.3458460569381714\n",
      "Epoch: 1299 Loss: 0.3450084626674652\n",
      "Epoch: 1300 Loss: 0.34487661719322205\n",
      "Epoch: 1301 Loss: 0.3442683219909668\n",
      "Epoch: 1302 Loss: 0.3435211777687073\n",
      "Epoch: 1303 Loss: 0.3431253731250763\n",
      "Epoch: 1304 Loss: 0.34310492873191833\n",
      "Epoch: 1305 Loss: 0.3423328101634979\n",
      "Epoch: 1306 Loss: 0.3414594531059265\n",
      "Epoch: 1307 Loss: 0.3407077193260193\n",
      "Epoch: 1308 Loss: 0.3405650854110718\n",
      "Epoch: 1309 Loss: 0.34054166078567505\n",
      "Epoch: 1310 Loss: 0.3397659957408905\n",
      "Epoch: 1311 Loss: 0.33932438492774963\n",
      "Epoch: 1312 Loss: 0.33866819739341736\n",
      "Epoch: 1313 Loss: 0.338412344455719\n",
      "Epoch: 1314 Loss: 0.33765363693237305\n",
      "Epoch: 1315 Loss: 0.33740630745887756\n",
      "Epoch: 1316 Loss: 0.33680009841918945\n",
      "Epoch: 1317 Loss: 0.3365783393383026\n",
      "Epoch: 1318 Loss: 0.3358977437019348\n",
      "Epoch: 1319 Loss: 0.3356630802154541\n",
      "Epoch: 1320 Loss: 0.334847629070282\n",
      "Epoch: 1321 Loss: 0.33475935459136963\n",
      "Epoch: 1322 Loss: 0.3344358205795288\n",
      "Epoch: 1323 Loss: 0.3335672914981842\n",
      "Epoch: 1324 Loss: 0.3330095112323761\n",
      "Epoch: 1325 Loss: 0.3323962986469269\n",
      "Epoch: 1326 Loss: 0.3324258327484131\n",
      "Epoch: 1327 Loss: 0.331650048494339\n",
      "Epoch: 1328 Loss: 0.3313504159450531\n",
      "Epoch: 1329 Loss: 0.33104246854782104\n",
      "Epoch: 1330 Loss: 0.32996508479118347\n",
      "Epoch: 1331 Loss: 0.3296065628528595\n",
      "Epoch: 1332 Loss: 0.32972028851509094\n",
      "Epoch: 1333 Loss: 0.3290826380252838\n",
      "Epoch: 1334 Loss: 0.3285422921180725\n",
      "Epoch: 1335 Loss: 0.3282775282859802\n",
      "Epoch: 1336 Loss: 0.3276320695877075\n",
      "Epoch: 1337 Loss: 0.32690659165382385\n",
      "Epoch: 1338 Loss: 0.3266492187976837\n",
      "Epoch: 1339 Loss: 0.3261965811252594\n",
      "Epoch: 1340 Loss: 0.3257092535495758\n",
      "Epoch: 1341 Loss: 0.3251951336860657\n",
      "Epoch: 1342 Loss: 0.3246808350086212\n",
      "Epoch: 1343 Loss: 0.32441768050193787\n",
      "Epoch: 1344 Loss: 0.32381555438041687\n",
      "Epoch: 1345 Loss: 0.32349732518196106\n",
      "Epoch: 1346 Loss: 0.3232337534427643\n",
      "Epoch: 1347 Loss: 0.32289570569992065\n",
      "Epoch: 1348 Loss: 0.3220885396003723\n",
      "Epoch: 1349 Loss: 0.3215884864330292\n",
      "Epoch: 1350 Loss: 0.3213016390800476\n",
      "Epoch: 1351 Loss: 0.32051756978034973\n",
      "Epoch: 1352 Loss: 0.3198743760585785\n",
      "Epoch: 1353 Loss: 0.31951892375946045\n",
      "Epoch: 1354 Loss: 0.31949928402900696\n",
      "Epoch: 1355 Loss: 0.3188193440437317\n",
      "Epoch: 1356 Loss: 0.31852713227272034\n",
      "Epoch: 1357 Loss: 0.3178460896015167\n",
      "Epoch: 1358 Loss: 0.31762900948524475\n",
      "Epoch: 1359 Loss: 0.31707918643951416\n",
      "Epoch: 1360 Loss: 0.3170599937438965\n",
      "Epoch: 1361 Loss: 0.31651535630226135\n",
      "Epoch: 1362 Loss: 0.31582653522491455\n",
      "Epoch: 1363 Loss: 0.3154059052467346\n",
      "Epoch: 1364 Loss: 0.31496191024780273\n",
      "Epoch: 1365 Loss: 0.31466883420944214\n",
      "Epoch: 1366 Loss: 0.3144570291042328\n",
      "Epoch: 1367 Loss: 0.3140110969543457\n",
      "Epoch: 1368 Loss: 0.3135107159614563\n",
      "Epoch: 1369 Loss: 0.3128182590007782\n",
      "Epoch: 1370 Loss: 0.3125019073486328\n",
      "Epoch: 1371 Loss: 0.31192660331726074\n",
      "Epoch: 1372 Loss: 0.3115485608577728\n",
      "Epoch: 1373 Loss: 0.31139975786209106\n",
      "Epoch: 1374 Loss: 0.3108198940753937\n",
      "Epoch: 1375 Loss: 0.3101564049720764\n",
      "Epoch: 1376 Loss: 0.3100915551185608\n",
      "Epoch: 1377 Loss: 0.309414267539978\n",
      "Epoch: 1378 Loss: 0.30888211727142334\n",
      "Epoch: 1379 Loss: 0.3086020350456238\n",
      "Epoch: 1380 Loss: 0.30840739607810974\n",
      "Epoch: 1381 Loss: 0.3079332709312439\n",
      "Epoch: 1382 Loss: 0.3071223199367523\n",
      "Epoch: 1383 Loss: 0.3067125380039215\n",
      "Epoch: 1384 Loss: 0.3066815137863159\n",
      "Epoch: 1385 Loss: 0.30622974038124084\n",
      "Epoch: 1386 Loss: 0.30570268630981445\n",
      "Epoch: 1387 Loss: 0.30508098006248474\n",
      "Epoch: 1388 Loss: 0.3049386441707611\n",
      "Epoch: 1389 Loss: 0.3046307861804962\n",
      "Epoch: 1390 Loss: 0.3040392994880676\n",
      "Epoch: 1391 Loss: 0.3037140369415283\n",
      "Epoch: 1392 Loss: 0.3034244477748871\n",
      "Epoch: 1393 Loss: 0.3027525246143341\n",
      "Epoch: 1394 Loss: 0.30260899662971497\n",
      "Epoch: 1395 Loss: 0.3021891713142395\n",
      "Epoch: 1396 Loss: 0.3016366958618164\n",
      "Epoch: 1397 Loss: 0.3011985719203949\n",
      "Epoch: 1398 Loss: 0.30110448598861694\n",
      "Epoch: 1399 Loss: 0.3005543649196625\n",
      "Epoch: 1400 Loss: 0.3002641797065735\n",
      "Epoch: 1401 Loss: 0.29954931139945984\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1402 Loss: 0.29939204454421997\n",
      "Epoch: 1403 Loss: 0.29862555861473083\n",
      "Epoch: 1404 Loss: 0.29839155077934265\n",
      "Epoch: 1405 Loss: 0.2978818416595459\n",
      "Epoch: 1406 Loss: 0.29752880334854126\n",
      "Epoch: 1407 Loss: 0.29764828085899353\n",
      "Epoch: 1408 Loss: 0.296686589717865\n",
      "Epoch: 1409 Loss: 0.2964232861995697\n",
      "Epoch: 1410 Loss: 0.29574665427207947\n",
      "Epoch: 1411 Loss: 0.2957584857940674\n",
      "Epoch: 1412 Loss: 0.29485228657722473\n",
      "Epoch: 1413 Loss: 0.29470449686050415\n",
      "Epoch: 1414 Loss: 0.2943603992462158\n",
      "Epoch: 1415 Loss: 0.29390591382980347\n",
      "Epoch: 1416 Loss: 0.29357513785362244\n",
      "Epoch: 1417 Loss: 0.2933928668498993\n",
      "Epoch: 1418 Loss: 0.2927872836589813\n",
      "Epoch: 1419 Loss: 0.2922026216983795\n",
      "Epoch: 1420 Loss: 0.29222404956817627\n",
      "Epoch: 1421 Loss: 0.29159119725227356\n",
      "Epoch: 1422 Loss: 0.2912271320819855\n",
      "Epoch: 1423 Loss: 0.29068514704704285\n",
      "Epoch: 1424 Loss: 0.29047834873199463\n",
      "Epoch: 1425 Loss: 0.2899259626865387\n",
      "Epoch: 1426 Loss: 0.2897847294807434\n",
      "Epoch: 1427 Loss: 0.2894900441169739\n",
      "Epoch: 1428 Loss: 0.2889467179775238\n",
      "Epoch: 1429 Loss: 0.2888067662715912\n",
      "Epoch: 1430 Loss: 0.2882559597492218\n",
      "Epoch: 1431 Loss: 0.28796276450157166\n",
      "Epoch: 1432 Loss: 0.28742867708206177\n",
      "Epoch: 1433 Loss: 0.28731998801231384\n",
      "Epoch: 1434 Loss: 0.2867060899734497\n",
      "Epoch: 1435 Loss: 0.2864408791065216\n",
      "Epoch: 1436 Loss: 0.28615692257881165\n",
      "Epoch: 1437 Loss: 0.28561973571777344\n",
      "Epoch: 1438 Loss: 0.2852543890476227\n",
      "Epoch: 1439 Loss: 0.2848889231681824\n",
      "Epoch: 1440 Loss: 0.2846914231777191\n",
      "Epoch: 1441 Loss: 0.2841792106628418\n",
      "Epoch: 1442 Loss: 0.28398391604423523\n",
      "Epoch: 1443 Loss: 0.28347864747047424\n",
      "Epoch: 1444 Loss: 0.2829660475254059\n",
      "Epoch: 1445 Loss: 0.28273528814315796\n",
      "Epoch: 1446 Loss: 0.2821609675884247\n",
      "Epoch: 1447 Loss: 0.2818852961063385\n",
      "Epoch: 1448 Loss: 0.28163161873817444\n",
      "Epoch: 1449 Loss: 0.28125128149986267\n",
      "Epoch: 1450 Loss: 0.28105759620666504\n",
      "Epoch: 1451 Loss: 0.2806459367275238\n",
      "Epoch: 1452 Loss: 0.28045031428337097\n",
      "Epoch: 1453 Loss: 0.27976784110069275\n",
      "Epoch: 1454 Loss: 0.2796628773212433\n",
      "Epoch: 1455 Loss: 0.27910351753234863\n",
      "Epoch: 1456 Loss: 0.2788756489753723\n",
      "Epoch: 1457 Loss: 0.2783362865447998\n",
      "Epoch: 1458 Loss: 0.2780338227748871\n",
      "Epoch: 1459 Loss: 0.27777889370918274\n",
      "Epoch: 1460 Loss: 0.27747777104377747\n",
      "Epoch: 1461 Loss: 0.27719613909721375\n",
      "Epoch: 1462 Loss: 0.27656248211860657\n",
      "Epoch: 1463 Loss: 0.27658841013908386\n",
      "Epoch: 1464 Loss: 0.2757984399795532\n",
      "Epoch: 1465 Loss: 0.2761388421058655\n",
      "Epoch: 1466 Loss: 0.27522164583206177\n",
      "Epoch: 1467 Loss: 0.2751440405845642\n",
      "Epoch: 1468 Loss: 0.2745217978954315\n",
      "Epoch: 1469 Loss: 0.2740626335144043\n",
      "Epoch: 1470 Loss: 0.2738729417324066\n",
      "Epoch: 1471 Loss: 0.27338704466819763\n",
      "Epoch: 1472 Loss: 0.2732656002044678\n",
      "Epoch: 1473 Loss: 0.27282121777534485\n",
      "Epoch: 1474 Loss: 0.27259939908981323\n",
      "Epoch: 1475 Loss: 0.272119402885437\n",
      "Epoch: 1476 Loss: 0.2716083824634552\n",
      "Epoch: 1477 Loss: 0.2713887393474579\n",
      "Epoch: 1478 Loss: 0.27087000012397766\n",
      "Epoch: 1479 Loss: 0.2707708179950714\n",
      "Epoch: 1480 Loss: 0.27065080404281616\n",
      "Epoch: 1481 Loss: 0.2699517607688904\n",
      "Epoch: 1482 Loss: 0.26984167098999023\n",
      "Epoch: 1483 Loss: 0.2693358361721039\n",
      "Epoch: 1484 Loss: 0.26911288499832153\n",
      "Epoch: 1485 Loss: 0.26866060495376587\n",
      "Epoch: 1486 Loss: 0.26850977540016174\n",
      "Epoch: 1487 Loss: 0.26788580417633057\n",
      "Epoch: 1488 Loss: 0.26766157150268555\n",
      "Epoch: 1489 Loss: 0.26753905415534973\n",
      "Epoch: 1490 Loss: 0.2670157253742218\n",
      "Epoch: 1491 Loss: 0.2669539749622345\n",
      "Epoch: 1492 Loss: 0.266133189201355\n",
      "Epoch: 1493 Loss: 0.26615628600120544\n",
      "Epoch: 1494 Loss: 0.2657780945301056\n",
      "Epoch: 1495 Loss: 0.26530754566192627\n",
      "Epoch: 1496 Loss: 0.2652694284915924\n",
      "Epoch: 1497 Loss: 0.26477858424186707\n",
      "Epoch: 1498 Loss: 0.26432961225509644\n",
      "Epoch: 1499 Loss: 0.264055073261261\n",
      "Epoch: 1500 Loss: 0.26380500197410583\n",
      "Epoch: 1501 Loss: 0.26362618803977966\n",
      "Epoch: 1502 Loss: 0.2631106674671173\n",
      "Epoch: 1503 Loss: 0.2626982629299164\n",
      "Epoch: 1504 Loss: 0.2627906799316406\n",
      "Epoch: 1505 Loss: 0.2620930075645447\n",
      "Epoch: 1506 Loss: 0.2619181275367737\n",
      "Epoch: 1507 Loss: 0.26164111495018005\n",
      "Epoch: 1508 Loss: 0.26109060645103455\n",
      "Epoch: 1509 Loss: 0.2608353793621063\n",
      "Epoch: 1510 Loss: 0.2607525885105133\n",
      "Epoch: 1511 Loss: 0.2600753605365753\n",
      "Epoch: 1512 Loss: 0.26016148924827576\n",
      "Epoch: 1513 Loss: 0.25942716002464294\n",
      "Epoch: 1514 Loss: 0.25920218229293823\n",
      "Epoch: 1515 Loss: 0.25905752182006836\n",
      "Epoch: 1516 Loss: 0.25833097100257874\n",
      "Epoch: 1517 Loss: 0.2582799196243286\n",
      "Epoch: 1518 Loss: 0.25782084465026855\n",
      "Epoch: 1519 Loss: 0.25756344199180603\n",
      "Epoch: 1520 Loss: 0.25726068019866943\n",
      "Epoch: 1521 Loss: 0.257060170173645\n",
      "Epoch: 1522 Loss: 0.2570611536502838\n",
      "Epoch: 1523 Loss: 0.2564880847930908\n",
      "Epoch: 1524 Loss: 0.25653529167175293\n",
      "Epoch: 1525 Loss: 0.25570374727249146\n",
      "Epoch: 1526 Loss: 0.2557053864002228\n",
      "Epoch: 1527 Loss: 0.2552562654018402\n",
      "Epoch: 1528 Loss: 0.25525909662246704\n",
      "Epoch: 1529 Loss: 0.25455257296562195\n",
      "Epoch: 1530 Loss: 0.25414302945137024\n",
      "Epoch: 1531 Loss: 0.2540394961833954\n",
      "Epoch: 1532 Loss: 0.253561407327652\n",
      "Epoch: 1533 Loss: 0.2533528208732605\n",
      "Epoch: 1534 Loss: 0.25297799706459045\n",
      "Epoch: 1535 Loss: 0.25276121497154236\n",
      "Epoch: 1536 Loss: 0.2526281177997589\n",
      "Epoch: 1537 Loss: 0.25219056010246277\n",
      "Epoch: 1538 Loss: 0.25203561782836914\n",
      "Epoch: 1539 Loss: 0.2515052855014801\n",
      "Epoch: 1540 Loss: 0.2513437569141388\n",
      "Epoch: 1541 Loss: 0.2508307099342346\n",
      "Epoch: 1542 Loss: 0.25087928771972656\n",
      "Epoch: 1543 Loss: 0.250214159488678\n",
      "Epoch: 1544 Loss: 0.25012049078941345\n",
      "Epoch: 1545 Loss: 0.24969987571239471\n",
      "Epoch: 1546 Loss: 0.24963629245758057\n",
      "Epoch: 1547 Loss: 0.24934779107570648\n",
      "Epoch: 1548 Loss: 0.24881385266780853\n",
      "Epoch: 1549 Loss: 0.2485162615776062\n",
      "Epoch: 1550 Loss: 0.24846653640270233\n",
      "Epoch: 1551 Loss: 0.2479032427072525\n",
      "Epoch: 1552 Loss: 0.24782533943653107\n",
      "Epoch: 1553 Loss: 0.24742065370082855\n",
      "Epoch: 1554 Loss: 0.24719469249248505\n",
      "Epoch: 1555 Loss: 0.24675367772579193\n",
      "Epoch: 1556 Loss: 0.2467062622308731\n",
      "Epoch: 1557 Loss: 0.24617671966552734\n",
      "Epoch: 1558 Loss: 0.2460823655128479\n",
      "Epoch: 1559 Loss: 0.24553588032722473\n",
      "Epoch: 1560 Loss: 0.2457137405872345\n",
      "Epoch: 1561 Loss: 0.24513879418373108\n",
      "Epoch: 1562 Loss: 0.2446235865354538\n",
      "Epoch: 1563 Loss: 0.24455656111240387\n",
      "Epoch: 1564 Loss: 0.24422067403793335\n",
      "Epoch: 1565 Loss: 0.24390994012355804\n",
      "Epoch: 1566 Loss: 0.24363332986831665\n",
      "Epoch: 1567 Loss: 0.24330246448516846\n",
      "Epoch: 1568 Loss: 0.24309727549552917\n",
      "Epoch: 1569 Loss: 0.24271920323371887\n",
      "Epoch: 1570 Loss: 0.24246306717395782\n",
      "Epoch: 1571 Loss: 0.24230395257472992\n",
      "Epoch: 1572 Loss: 0.24187447130680084\n",
      "Epoch: 1573 Loss: 0.24173066020011902\n",
      "Epoch: 1574 Loss: 0.24131403863430023\n",
      "Epoch: 1575 Loss: 0.2412417232990265\n",
      "Epoch: 1576 Loss: 0.2407202422618866\n",
      "Epoch: 1577 Loss: 0.2404310405254364\n",
      "Epoch: 1578 Loss: 0.2402697205543518\n",
      "Epoch: 1579 Loss: 0.23991656303405762\n",
      "Epoch: 1580 Loss: 0.2398553341627121\n",
      "Epoch: 1581 Loss: 0.23929227888584137\n",
      "Epoch: 1582 Loss: 0.23913073539733887\n",
      "Epoch: 1583 Loss: 0.23878872394561768\n",
      "Epoch: 1584 Loss: 0.23848558962345123\n",
      "Epoch: 1585 Loss: 0.2385353296995163\n",
      "Epoch: 1586 Loss: 0.23803499341011047\n",
      "Epoch: 1587 Loss: 0.23774881660938263\n",
      "Epoch: 1588 Loss: 0.23755815625190735\n",
      "Epoch: 1589 Loss: 0.2371261715888977\n",
      "Epoch: 1590 Loss: 0.2368766814470291\n",
      "Epoch: 1591 Loss: 0.23660524189472198\n",
      "Epoch: 1592 Loss: 0.23656046390533447\n",
      "Epoch: 1593 Loss: 0.23603864014148712\n",
      "Epoch: 1594 Loss: 0.23587943613529205\n",
      "Epoch: 1595 Loss: 0.23554907739162445\n",
      "Epoch: 1596 Loss: 0.23527061939239502\n",
      "Epoch: 1597 Loss: 0.23477676510810852\n",
      "Epoch: 1598 Loss: 0.234871506690979\n",
      "Epoch: 1599 Loss: 0.23435813188552856\n",
      "Epoch: 1600 Loss: 0.23429414629936218\n",
      "Epoch: 1601 Loss: 0.23385418951511383\n",
      "Epoch: 1602 Loss: 0.23378868401050568\n",
      "Epoch: 1603 Loss: 0.23332495987415314\n",
      "Epoch: 1604 Loss: 0.23327413201332092\n",
      "Epoch: 1605 Loss: 0.23275434970855713\n",
      "Epoch: 1606 Loss: 0.23260918259620667\n",
      "Epoch: 1607 Loss: 0.2321396768093109\n",
      "Epoch: 1608 Loss: 0.23196929693222046\n",
      "Epoch: 1609 Loss: 0.2317045032978058\n",
      "Epoch: 1610 Loss: 0.2315765917301178\n",
      "Epoch: 1611 Loss: 0.23103906214237213\n",
      "Epoch: 1612 Loss: 0.23089906573295593\n",
      "Epoch: 1613 Loss: 0.23066475987434387\n",
      "Epoch: 1614 Loss: 0.23035357892513275\n",
      "Epoch: 1615 Loss: 0.23010072112083435\n",
      "Epoch: 1616 Loss: 0.22982336580753326\n",
      "Epoch: 1617 Loss: 0.22961148619651794\n",
      "Epoch: 1618 Loss: 0.22920088469982147\n",
      "Epoch: 1619 Loss: 0.2291100025177002\n",
      "Epoch: 1620 Loss: 0.22892211377620697\n",
      "Epoch: 1621 Loss: 0.22832909226417542\n",
      "Epoch: 1622 Loss: 0.22817811369895935\n",
      "Epoch: 1623 Loss: 0.22785399854183197\n",
      "Epoch: 1624 Loss: 0.22785092890262604\n",
      "Epoch: 1625 Loss: 0.22734865546226501\n",
      "Epoch: 1626 Loss: 0.22733648121356964\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1627 Loss: 0.22674560546875\n",
      "Epoch: 1628 Loss: 0.22668349742889404\n",
      "Epoch: 1629 Loss: 0.2263282835483551\n",
      "Epoch: 1630 Loss: 0.2261933535337448\n",
      "Epoch: 1631 Loss: 0.22580178081989288\n",
      "Epoch: 1632 Loss: 0.22570860385894775\n",
      "Epoch: 1633 Loss: 0.22543734312057495\n",
      "Epoch: 1634 Loss: 0.2251395583152771\n",
      "Epoch: 1635 Loss: 0.22485210001468658\n",
      "Epoch: 1636 Loss: 0.2246858775615692\n",
      "Epoch: 1637 Loss: 0.22418570518493652\n",
      "Epoch: 1638 Loss: 0.22408628463745117\n",
      "Epoch: 1639 Loss: 0.22388315200805664\n",
      "Epoch: 1640 Loss: 0.22342777252197266\n",
      "Epoch: 1641 Loss: 0.22338324785232544\n",
      "Epoch: 1642 Loss: 0.2229192554950714\n",
      "Epoch: 1643 Loss: 0.22274376451969147\n",
      "Epoch: 1644 Loss: 0.22254346311092377\n",
      "Epoch: 1645 Loss: 0.2222624570131302\n",
      "Epoch: 1646 Loss: 0.22198820114135742\n",
      "Epoch: 1647 Loss: 0.22187311947345734\n",
      "Epoch: 1648 Loss: 0.22135613858699799\n",
      "Epoch: 1649 Loss: 0.2214001566171646\n",
      "Epoch: 1650 Loss: 0.22109432518482208\n",
      "Epoch: 1651 Loss: 0.22070461511611938\n",
      "Epoch: 1652 Loss: 0.22048650681972504\n",
      "Epoch: 1653 Loss: 0.22015434503555298\n",
      "Epoch: 1654 Loss: 0.21992450952529907\n",
      "Epoch: 1655 Loss: 0.21989914774894714\n",
      "Epoch: 1656 Loss: 0.2193748652935028\n",
      "Epoch: 1657 Loss: 0.21952597796916962\n",
      "Epoch: 1658 Loss: 0.21906031668186188\n",
      "Epoch: 1659 Loss: 0.2189050316810608\n",
      "Epoch: 1660 Loss: 0.21862643957138062\n",
      "Epoch: 1661 Loss: 0.2183363437652588\n",
      "Epoch: 1662 Loss: 0.21814271807670593\n",
      "Epoch: 1663 Loss: 0.21786099672317505\n",
      "Epoch: 1664 Loss: 0.21764308214187622\n",
      "Epoch: 1665 Loss: 0.21758165955543518\n",
      "Epoch: 1666 Loss: 0.21706433594226837\n",
      "Epoch: 1667 Loss: 0.21697458624839783\n",
      "Epoch: 1668 Loss: 0.21653559803962708\n",
      "Epoch: 1669 Loss: 0.21637850999832153\n",
      "Epoch: 1670 Loss: 0.21614474058151245\n",
      "Epoch: 1671 Loss: 0.21596401929855347\n",
      "Epoch: 1672 Loss: 0.21559184789657593\n",
      "Epoch: 1673 Loss: 0.21550866961479187\n",
      "Epoch: 1674 Loss: 0.2150568962097168\n",
      "Epoch: 1675 Loss: 0.21506957709789276\n",
      "Epoch: 1676 Loss: 0.21473966538906097\n",
      "Epoch: 1677 Loss: 0.2145262360572815\n",
      "Epoch: 1678 Loss: 0.21424663066864014\n",
      "Epoch: 1679 Loss: 0.2139337807893753\n",
      "Epoch: 1680 Loss: 0.21398988366127014\n",
      "Epoch: 1681 Loss: 0.21361063420772552\n",
      "Epoch: 1682 Loss: 0.2133137583732605\n",
      "Epoch: 1683 Loss: 0.2129981964826584\n",
      "Epoch: 1684 Loss: 0.21292172372341156\n",
      "Epoch: 1685 Loss: 0.21269194781780243\n",
      "Epoch: 1686 Loss: 0.2124817669391632\n",
      "Epoch: 1687 Loss: 0.21211251616477966\n",
      "Epoch: 1688 Loss: 0.21184535324573517\n",
      "Epoch: 1689 Loss: 0.21190765500068665\n",
      "Epoch: 1690 Loss: 0.21151511371135712\n",
      "Epoch: 1691 Loss: 0.21121296286582947\n",
      "Epoch: 1692 Loss: 0.2110932320356369\n",
      "Epoch: 1693 Loss: 0.21079719066619873\n",
      "Epoch: 1694 Loss: 0.2106383740901947\n",
      "Epoch: 1695 Loss: 0.2103096842765808\n",
      "Epoch: 1696 Loss: 0.21007362008094788\n",
      "Epoch: 1697 Loss: 0.20991283655166626\n",
      "Epoch: 1698 Loss: 0.20970502495765686\n",
      "Epoch: 1699 Loss: 0.20932604372501373\n",
      "Epoch: 1700 Loss: 0.20915649831295013\n",
      "Epoch: 1701 Loss: 0.2089606076478958\n",
      "Epoch: 1702 Loss: 0.20867541432380676\n",
      "Epoch: 1703 Loss: 0.20861630141735077\n",
      "Epoch: 1704 Loss: 0.20842337608337402\n",
      "Epoch: 1705 Loss: 0.2080928087234497\n",
      "Epoch: 1706 Loss: 0.2076866328716278\n",
      "Epoch: 1707 Loss: 0.20768165588378906\n",
      "Epoch: 1708 Loss: 0.2074737846851349\n",
      "Epoch: 1709 Loss: 0.20708037912845612\n",
      "Epoch: 1710 Loss: 0.20699958503246307\n",
      "Epoch: 1711 Loss: 0.2067829668521881\n",
      "Epoch: 1712 Loss: 0.20644034445285797\n",
      "Epoch: 1713 Loss: 0.20621080696582794\n",
      "Epoch: 1714 Loss: 0.20609807968139648\n",
      "Epoch: 1715 Loss: 0.20583179593086243\n",
      "Epoch: 1716 Loss: 0.2054896503686905\n",
      "Epoch: 1717 Loss: 0.20536454021930695\n",
      "Epoch: 1718 Loss: 0.20520709455013275\n",
      "Epoch: 1719 Loss: 0.2048986256122589\n",
      "Epoch: 1720 Loss: 0.20473580062389374\n",
      "Epoch: 1721 Loss: 0.20445838570594788\n",
      "Epoch: 1722 Loss: 0.2043219804763794\n",
      "Epoch: 1723 Loss: 0.2042069286108017\n",
      "Epoch: 1724 Loss: 0.20386064052581787\n",
      "Epoch: 1725 Loss: 0.20381736755371094\n",
      "Epoch: 1726 Loss: 0.20342785120010376\n",
      "Epoch: 1727 Loss: 0.2032468020915985\n",
      "Epoch: 1728 Loss: 0.20312705636024475\n",
      "Epoch: 1729 Loss: 0.20272718369960785\n",
      "Epoch: 1730 Loss: 0.20261868834495544\n",
      "Epoch: 1731 Loss: 0.20246994495391846\n",
      "Epoch: 1732 Loss: 0.20210382342338562\n",
      "Epoch: 1733 Loss: 0.20195484161376953\n",
      "Epoch: 1734 Loss: 0.20190492272377014\n",
      "Epoch: 1735 Loss: 0.20143592357635498\n",
      "Epoch: 1736 Loss: 0.20133842527866364\n",
      "Epoch: 1737 Loss: 0.20125587284564972\n",
      "Epoch: 1738 Loss: 0.2008640319108963\n",
      "Epoch: 1739 Loss: 0.20066779851913452\n",
      "Epoch: 1740 Loss: 0.2005038559436798\n",
      "Epoch: 1741 Loss: 0.20022325217723846\n",
      "Epoch: 1742 Loss: 0.20008371770381927\n",
      "Epoch: 1743 Loss: 0.19984059035778046\n",
      "Epoch: 1744 Loss: 0.19968588650226593\n",
      "Epoch: 1745 Loss: 0.19947370886802673\n",
      "Epoch: 1746 Loss: 0.19914327561855316\n",
      "Epoch: 1747 Loss: 0.19900059700012207\n",
      "Epoch: 1748 Loss: 0.1988096535205841\n",
      "Epoch: 1749 Loss: 0.19857360422611237\n",
      "Epoch: 1750 Loss: 0.19855865836143494\n",
      "Epoch: 1751 Loss: 0.19809113442897797\n",
      "Epoch: 1752 Loss: 0.19804039597511292\n",
      "Epoch: 1753 Loss: 0.19786818325519562\n",
      "Epoch: 1754 Loss: 0.19743093848228455\n",
      "Epoch: 1755 Loss: 0.1973700374364853\n",
      "Epoch: 1756 Loss: 0.19714155793190002\n",
      "Epoch: 1757 Loss: 0.19692842662334442\n",
      "Epoch: 1758 Loss: 0.19666996598243713\n",
      "Epoch: 1759 Loss: 0.19658929109573364\n",
      "Epoch: 1760 Loss: 0.19632357358932495\n",
      "Epoch: 1761 Loss: 0.1960080862045288\n",
      "Epoch: 1762 Loss: 0.19595560431480408\n",
      "Epoch: 1763 Loss: 0.1956443041563034\n",
      "Epoch: 1764 Loss: 0.1954701989889145\n",
      "Epoch: 1765 Loss: 0.19553247094154358\n",
      "Epoch: 1766 Loss: 0.19516082108020782\n",
      "Epoch: 1767 Loss: 0.19487392902374268\n",
      "Epoch: 1768 Loss: 0.19491522014141083\n",
      "Epoch: 1769 Loss: 0.19451212882995605\n",
      "Epoch: 1770 Loss: 0.1942727267742157\n",
      "Epoch: 1771 Loss: 0.19415345788002014\n",
      "Epoch: 1772 Loss: 0.19394837319850922\n",
      "Epoch: 1773 Loss: 0.19368809461593628\n",
      "Epoch: 1774 Loss: 0.19371332228183746\n",
      "Epoch: 1775 Loss: 0.193304181098938\n",
      "Epoch: 1776 Loss: 0.1931331902742386\n",
      "Epoch: 1777 Loss: 0.19299101829528809\n",
      "Epoch: 1778 Loss: 0.19277653098106384\n",
      "Epoch: 1779 Loss: 0.19247284531593323\n",
      "Epoch: 1780 Loss: 0.19239307940006256\n",
      "Epoch: 1781 Loss: 0.19213536381721497\n",
      "Epoch: 1782 Loss: 0.19207406044006348\n",
      "Epoch: 1783 Loss: 0.19169647991657257\n",
      "Epoch: 1784 Loss: 0.19165459275245667\n",
      "Epoch: 1785 Loss: 0.1913493126630783\n",
      "Epoch: 1786 Loss: 0.19113850593566895\n",
      "Epoch: 1787 Loss: 0.1910400092601776\n",
      "Epoch: 1788 Loss: 0.19093656539916992\n",
      "Epoch: 1789 Loss: 0.19056886434555054\n",
      "Epoch: 1790 Loss: 0.1904507279396057\n",
      "Epoch: 1791 Loss: 0.19017264246940613\n",
      "Epoch: 1792 Loss: 0.19001229107379913\n",
      "Epoch: 1793 Loss: 0.18990883231163025\n",
      "Epoch: 1794 Loss: 0.1897384524345398\n",
      "Epoch: 1795 Loss: 0.18942327797412872\n",
      "Epoch: 1796 Loss: 0.18946751952171326\n",
      "Epoch: 1797 Loss: 0.18910594284534454\n",
      "Epoch: 1798 Loss: 0.18885117769241333\n",
      "Epoch: 1799 Loss: 0.18885573744773865\n",
      "Epoch: 1800 Loss: 0.18855461478233337\n",
      "Epoch: 1801 Loss: 0.18839973211288452\n",
      "Epoch: 1802 Loss: 0.1881677806377411\n",
      "Epoch: 1803 Loss: 0.1879487931728363\n",
      "Epoch: 1804 Loss: 0.18774756789207458\n",
      "Epoch: 1805 Loss: 0.1876593977212906\n",
      "Epoch: 1806 Loss: 0.1874494105577469\n",
      "Epoch: 1807 Loss: 0.18715588748455048\n",
      "Epoch: 1808 Loss: 0.1871163696050644\n",
      "Epoch: 1809 Loss: 0.18676255643367767\n",
      "Epoch: 1810 Loss: 0.1867331564426422\n",
      "Epoch: 1811 Loss: 0.18644505739212036\n",
      "Epoch: 1812 Loss: 0.1863548904657364\n",
      "Epoch: 1813 Loss: 0.1861000955104828\n",
      "Epoch: 1814 Loss: 0.1859178990125656\n",
      "Epoch: 1815 Loss: 0.18578435480594635\n",
      "Epoch: 1816 Loss: 0.18551965057849884\n",
      "Epoch: 1817 Loss: 0.1853707730770111\n",
      "Epoch: 1818 Loss: 0.18522045016288757\n",
      "Epoch: 1819 Loss: 0.1849566251039505\n",
      "Epoch: 1820 Loss: 0.18474167585372925\n",
      "Epoch: 1821 Loss: 0.18475782871246338\n",
      "Epoch: 1822 Loss: 0.18437059223651886\n",
      "Epoch: 1823 Loss: 0.18425966799259186\n",
      "Epoch: 1824 Loss: 0.1841723769903183\n",
      "Epoch: 1825 Loss: 0.18392713367938995\n",
      "Epoch: 1826 Loss: 0.18371470272541046\n",
      "Epoch: 1827 Loss: 0.1835126280784607\n",
      "Epoch: 1828 Loss: 0.183268740773201\n",
      "Epoch: 1829 Loss: 0.18318842351436615\n",
      "Epoch: 1830 Loss: 0.18305525183677673\n",
      "Epoch: 1831 Loss: 0.18278911709785461\n",
      "Epoch: 1832 Loss: 0.18247489631175995\n",
      "Epoch: 1833 Loss: 0.18255412578582764\n",
      "Epoch: 1834 Loss: 0.18220041692256927\n",
      "Epoch: 1835 Loss: 0.18216447532176971\n",
      "Epoch: 1836 Loss: 0.1819448471069336\n",
      "Epoch: 1837 Loss: 0.18177533149719238\n",
      "Epoch: 1838 Loss: 0.1815531700849533\n",
      "Epoch: 1839 Loss: 0.18139415979385376\n",
      "Epoch: 1840 Loss: 0.1811562478542328\n",
      "Epoch: 1841 Loss: 0.1810426115989685\n",
      "Epoch: 1842 Loss: 0.18081118166446686\n",
      "Epoch: 1843 Loss: 0.18066298961639404\n",
      "Epoch: 1844 Loss: 0.1804254651069641\n",
      "Epoch: 1845 Loss: 0.1803491860628128\n",
      "Epoch: 1846 Loss: 0.18015089631080627\n",
      "Epoch: 1847 Loss: 0.17992228269577026\n",
      "Epoch: 1848 Loss: 0.17979411780834198\n",
      "Epoch: 1849 Loss: 0.17959262430667877\n",
      "Epoch: 1850 Loss: 0.17943842709064484\n",
      "Epoch: 1851 Loss: 0.17933261394500732\n",
      "Epoch: 1852 Loss: 0.17918482422828674\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1853 Loss: 0.17884890735149384\n",
      "Epoch: 1854 Loss: 0.17870470881462097\n",
      "Epoch: 1855 Loss: 0.1785806566476822\n",
      "Epoch: 1856 Loss: 0.17833738029003143\n",
      "Epoch: 1857 Loss: 0.1781667023897171\n",
      "Epoch: 1858 Loss: 0.17819958925247192\n",
      "Epoch: 1859 Loss: 0.17786657810211182\n",
      "Epoch: 1860 Loss: 0.17756561934947968\n",
      "Epoch: 1861 Loss: 0.1775447428226471\n",
      "Epoch: 1862 Loss: 0.17735233902931213\n",
      "Epoch: 1863 Loss: 0.17718777060508728\n",
      "Epoch: 1864 Loss: 0.17704997956752777\n",
      "Epoch: 1865 Loss: 0.17674113810062408\n",
      "Epoch: 1866 Loss: 0.17658349871635437\n",
      "Epoch: 1867 Loss: 0.17650727927684784\n",
      "Epoch: 1868 Loss: 0.17637978494167328\n",
      "Epoch: 1869 Loss: 0.17611035704612732\n",
      "Epoch: 1870 Loss: 0.17586664855480194\n",
      "Epoch: 1871 Loss: 0.17589786648750305\n",
      "Epoch: 1872 Loss: 0.17563015222549438\n",
      "Epoch: 1873 Loss: 0.17554159462451935\n",
      "Epoch: 1874 Loss: 0.17527280747890472\n",
      "Epoch: 1875 Loss: 0.17503415048122406\n",
      "Epoch: 1876 Loss: 0.17505910992622375\n",
      "Epoch: 1877 Loss: 0.17476093769073486\n",
      "Epoch: 1878 Loss: 0.17462843656539917\n",
      "Epoch: 1879 Loss: 0.17443621158599854\n",
      "Epoch: 1880 Loss: 0.17432154715061188\n",
      "Epoch: 1881 Loss: 0.17403671145439148\n",
      "Epoch: 1882 Loss: 0.1739906370639801\n",
      "Epoch: 1883 Loss: 0.17373701930046082\n",
      "Epoch: 1884 Loss: 0.17366768419742584\n",
      "Epoch: 1885 Loss: 0.17336702346801758\n",
      "Epoch: 1886 Loss: 0.1733820140361786\n",
      "Epoch: 1887 Loss: 0.17303892970085144\n",
      "Epoch: 1888 Loss: 0.17292830348014832\n",
      "Epoch: 1889 Loss: 0.17290091514587402\n",
      "Epoch: 1890 Loss: 0.1726575344800949\n",
      "Epoch: 1891 Loss: 0.17238456010818481\n",
      "Epoch: 1892 Loss: 0.17234483361244202\n",
      "Epoch: 1893 Loss: 0.17206642031669617\n",
      "Epoch: 1894 Loss: 0.17194443941116333\n",
      "Epoch: 1895 Loss: 0.17192518711090088\n",
      "Epoch: 1896 Loss: 0.1715741753578186\n",
      "Epoch: 1897 Loss: 0.17141105234622955\n",
      "Epoch: 1898 Loss: 0.1713477224111557\n",
      "Epoch: 1899 Loss: 0.17112629115581512\n",
      "Epoch: 1900 Loss: 0.17095132172107697\n",
      "Epoch: 1901 Loss: 0.17083966732025146\n",
      "Epoch: 1902 Loss: 0.1706770658493042\n",
      "Epoch: 1903 Loss: 0.1704627275466919\n",
      "Epoch: 1904 Loss: 0.1704278439283371\n",
      "Epoch: 1905 Loss: 0.17028799653053284\n",
      "Epoch: 1906 Loss: 0.16995586454868317\n",
      "Epoch: 1907 Loss: 0.16992318630218506\n",
      "Epoch: 1908 Loss: 0.16968268156051636\n",
      "Epoch: 1909 Loss: 0.16958971321582794\n",
      "Epoch: 1910 Loss: 0.16938842833042145\n",
      "Epoch: 1911 Loss: 0.16929864883422852\n",
      "Epoch: 1912 Loss: 0.16916055977344513\n",
      "Epoch: 1913 Loss: 0.16887114942073822\n",
      "Epoch: 1914 Loss: 0.16879861056804657\n",
      "Epoch: 1915 Loss: 0.1685674786567688\n",
      "Epoch: 1916 Loss: 0.16855208575725555\n",
      "Epoch: 1917 Loss: 0.16839301586151123\n",
      "Epoch: 1918 Loss: 0.168160080909729\n",
      "Epoch: 1919 Loss: 0.16790097951889038\n",
      "Epoch: 1920 Loss: 0.16785383224487305\n",
      "Epoch: 1921 Loss: 0.1676185429096222\n",
      "Epoch: 1922 Loss: 0.1676006019115448\n",
      "Epoch: 1923 Loss: 0.16733196377754211\n",
      "Epoch: 1924 Loss: 0.16714346408843994\n",
      "Epoch: 1925 Loss: 0.1670195311307907\n",
      "Epoch: 1926 Loss: 0.16683687269687653\n",
      "Epoch: 1927 Loss: 0.16672776639461517\n",
      "Epoch: 1928 Loss: 0.16652804613113403\n",
      "Epoch: 1929 Loss: 0.16629303991794586\n",
      "Epoch: 1930 Loss: 0.16622576117515564\n",
      "Epoch: 1931 Loss: 0.16607289016246796\n",
      "Epoch: 1932 Loss: 0.16588477790355682\n",
      "Epoch: 1933 Loss: 0.16575336456298828\n",
      "Epoch: 1934 Loss: 0.16561585664749146\n",
      "Epoch: 1935 Loss: 0.165415421128273\n",
      "Epoch: 1936 Loss: 0.16528505086898804\n",
      "Epoch: 1937 Loss: 0.16513441503047943\n",
      "Epoch: 1938 Loss: 0.16500310599803925\n",
      "Epoch: 1939 Loss: 0.16490116715431213\n",
      "Epoch: 1940 Loss: 0.16465994715690613\n",
      "Epoch: 1941 Loss: 0.1644248068332672\n",
      "Epoch: 1942 Loss: 0.16446417570114136\n",
      "Epoch: 1943 Loss: 0.16427017748355865\n",
      "Epoch: 1944 Loss: 0.16400904953479767\n",
      "Epoch: 1945 Loss: 0.1639009714126587\n",
      "Epoch: 1946 Loss: 0.16377398371696472\n",
      "Epoch: 1947 Loss: 0.16354499757289886\n",
      "Epoch: 1948 Loss: 0.1635456383228302\n",
      "Epoch: 1949 Loss: 0.16338251531124115\n",
      "Epoch: 1950 Loss: 0.1632671058177948\n",
      "Epoch: 1951 Loss: 0.16295842826366425\n",
      "Epoch: 1952 Loss: 0.1628691554069519\n",
      "Epoch: 1953 Loss: 0.16266918182373047\n",
      "Epoch: 1954 Loss: 0.1625988483428955\n",
      "Epoch: 1955 Loss: 0.16246002912521362\n",
      "Epoch: 1956 Loss: 0.16234728693962097\n",
      "Epoch: 1957 Loss: 0.16206996142864227\n",
      "Epoch: 1958 Loss: 0.16200706362724304\n",
      "Epoch: 1959 Loss: 0.16186152398586273\n",
      "Epoch: 1960 Loss: 0.16165919601917267\n",
      "Epoch: 1961 Loss: 0.1615443080663681\n",
      "Epoch: 1962 Loss: 0.16140758991241455\n",
      "Epoch: 1963 Loss: 0.16117319464683533\n",
      "Epoch: 1964 Loss: 0.16110201179981232\n",
      "Epoch: 1965 Loss: 0.16098493337631226\n",
      "Epoch: 1966 Loss: 0.1608179211616516\n",
      "Epoch: 1967 Loss: 0.16056086122989655\n",
      "Epoch: 1968 Loss: 0.1605926752090454\n",
      "Epoch: 1969 Loss: 0.16038136184215546\n",
      "Epoch: 1970 Loss: 0.1601875126361847\n",
      "Epoch: 1971 Loss: 0.160106360912323\n",
      "Epoch: 1972 Loss: 0.1599302887916565\n",
      "Epoch: 1973 Loss: 0.15980495512485504\n",
      "Epoch: 1974 Loss: 0.15967784821987152\n",
      "Epoch: 1975 Loss: 0.15952923893928528\n",
      "Epoch: 1976 Loss: 0.15934932231903076\n",
      "Epoch: 1977 Loss: 0.1592247039079666\n",
      "Epoch: 1978 Loss: 0.15905968844890594\n",
      "Epoch: 1979 Loss: 0.15893025696277618\n",
      "Epoch: 1980 Loss: 0.15891185402870178\n",
      "Epoch: 1981 Loss: 0.15868373215198517\n",
      "Epoch: 1982 Loss: 0.15846700966358185\n",
      "Epoch: 1983 Loss: 0.15835192799568176\n",
      "Epoch: 1984 Loss: 0.15815980732440948\n",
      "Epoch: 1985 Loss: 0.1581583023071289\n",
      "Epoch: 1986 Loss: 0.15798351168632507\n",
      "Epoch: 1987 Loss: 0.157756045460701\n",
      "Epoch: 1988 Loss: 0.15767262876033783\n",
      "Epoch: 1989 Loss: 0.15750117599964142\n",
      "Epoch: 1990 Loss: 0.15736229717731476\n",
      "Epoch: 1991 Loss: 0.15723109245300293\n",
      "Epoch: 1992 Loss: 0.15708675980567932\n",
      "Epoch: 1993 Loss: 0.15695513784885406\n",
      "Epoch: 1994 Loss: 0.15680477023124695\n",
      "Epoch: 1995 Loss: 0.15673209726810455\n",
      "Epoch: 1996 Loss: 0.15651462972164154\n",
      "Epoch: 1997 Loss: 0.15640789270401\n",
      "Epoch: 1998 Loss: 0.1562877595424652\n",
      "Epoch: 1999 Loss: 0.1560904085636139\n",
      "Epoch: 2000 Loss: 0.15595375001430511\n",
      "Epoch: 2001 Loss: 0.15591201186180115\n",
      "Epoch: 2002 Loss: 0.15570077300071716\n",
      "Epoch: 2003 Loss: 0.15552659332752228\n",
      "Epoch: 2004 Loss: 0.15547612309455872\n",
      "Epoch: 2005 Loss: 0.15524627268314362\n",
      "Epoch: 2006 Loss: 0.1551646739244461\n",
      "Epoch: 2007 Loss: 0.15505357086658478\n",
      "Epoch: 2008 Loss: 0.15491929650306702\n",
      "Epoch: 2009 Loss: 0.1547301709651947\n",
      "Epoch: 2010 Loss: 0.1546023190021515\n",
      "Epoch: 2011 Loss: 0.1544627696275711\n",
      "Epoch: 2012 Loss: 0.1542579084634781\n",
      "Epoch: 2013 Loss: 0.15421923995018005\n",
      "Epoch: 2014 Loss: 0.15404236316680908\n",
      "Epoch: 2015 Loss: 0.15390123426914215\n",
      "Epoch: 2016 Loss: 0.1537243127822876\n",
      "Epoch: 2017 Loss: 0.15370596945285797\n",
      "Epoch: 2018 Loss: 0.15347182750701904\n",
      "Epoch: 2019 Loss: 0.1533169001340866\n",
      "Epoch: 2020 Loss: 0.1532401293516159\n",
      "Epoch: 2021 Loss: 0.1531054526567459\n",
      "Epoch: 2022 Loss: 0.1530386209487915\n",
      "Epoch: 2023 Loss: 0.15286774933338165\n",
      "Epoch: 2024 Loss: 0.15264521539211273\n",
      "Epoch: 2025 Loss: 0.1525268703699112\n",
      "Epoch: 2026 Loss: 0.15250174701213837\n",
      "Epoch: 2027 Loss: 0.15235359966754913\n",
      "Epoch: 2028 Loss: 0.1521029770374298\n",
      "Epoch: 2029 Loss: 0.151947483420372\n",
      "Epoch: 2030 Loss: 0.15186041593551636\n",
      "Epoch: 2031 Loss: 0.15173117816448212\n",
      "Epoch: 2032 Loss: 0.15162353217601776\n",
      "Epoch: 2033 Loss: 0.15147194266319275\n",
      "Epoch: 2034 Loss: 0.15132197737693787\n",
      "Epoch: 2035 Loss: 0.15121281147003174\n",
      "Epoch: 2036 Loss: 0.1510867327451706\n",
      "Epoch: 2037 Loss: 0.1509471982717514\n",
      "Epoch: 2038 Loss: 0.15081588923931122\n",
      "Epoch: 2039 Loss: 0.15064220130443573\n",
      "Epoch: 2040 Loss: 0.15059170126914978\n",
      "Epoch: 2041 Loss: 0.15045896172523499\n",
      "Epoch: 2042 Loss: 0.15031656622886658\n",
      "Epoch: 2043 Loss: 0.15010499954223633\n",
      "Epoch: 2044 Loss: 0.1500505954027176\n",
      "Epoch: 2045 Loss: 0.15001612901687622\n",
      "Epoch: 2046 Loss: 0.14973777532577515\n",
      "Epoch: 2047 Loss: 0.14965036511421204\n",
      "Epoch: 2048 Loss: 0.14945045113563538\n",
      "Epoch: 2049 Loss: 0.1493753343820572\n",
      "Epoch: 2050 Loss: 0.1492544412612915\n",
      "Epoch: 2051 Loss: 0.1491100937128067\n",
      "Epoch: 2052 Loss: 0.1489887535572052\n",
      "Epoch: 2053 Loss: 0.14885839819908142\n",
      "Epoch: 2054 Loss: 0.14887356758117676\n",
      "Epoch: 2055 Loss: 0.14857029914855957\n",
      "Epoch: 2056 Loss: 0.14853888750076294\n",
      "Epoch: 2057 Loss: 0.148344948887825\n",
      "Epoch: 2058 Loss: 0.14822722971439362\n",
      "Epoch: 2059 Loss: 0.1480407565832138\n",
      "Epoch: 2060 Loss: 0.1479434370994568\n",
      "Epoch: 2061 Loss: 0.14783607423305511\n",
      "Epoch: 2062 Loss: 0.14772899448871613\n",
      "Epoch: 2063 Loss: 0.1476249098777771\n",
      "Epoch: 2064 Loss: 0.1474628448486328\n",
      "Epoch: 2065 Loss: 0.14728668332099915\n",
      "Epoch: 2066 Loss: 0.14724978804588318\n",
      "Epoch: 2067 Loss: 0.14705400168895721\n",
      "Epoch: 2068 Loss: 0.14699505269527435\n",
      "Epoch: 2069 Loss: 0.1467762589454651\n",
      "Epoch: 2070 Loss: 0.14672712981700897\n",
      "Epoch: 2071 Loss: 0.14652429521083832\n",
      "Epoch: 2072 Loss: 0.14640973508358002\n",
      "Epoch: 2073 Loss: 0.14632484316825867\n",
      "Epoch: 2074 Loss: 0.14630575478076935\n",
      "Epoch: 2075 Loss: 0.1460210084915161\n",
      "Epoch: 2076 Loss: 0.1459614783525467\n",
      "Epoch: 2077 Loss: 0.1458507776260376\n",
      "Epoch: 2078 Loss: 0.14570924639701843\n",
      "Epoch: 2079 Loss: 0.14550785720348358\n",
      "Epoch: 2080 Loss: 0.14543302357196808\n",
      "Epoch: 2081 Loss: 0.14529186487197876\n",
      "Epoch: 2082 Loss: 0.14516879618167877\n",
      "Epoch: 2083 Loss: 0.1450660675764084\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2084 Loss: 0.14500771462917328\n",
      "Epoch: 2085 Loss: 0.14479458332061768\n",
      "Epoch: 2086 Loss: 0.14463625848293304\n",
      "Epoch: 2087 Loss: 0.14454081654548645\n",
      "Epoch: 2088 Loss: 0.1444566398859024\n",
      "Epoch: 2089 Loss: 0.14431659877300262\n",
      "Epoch: 2090 Loss: 0.14415238797664642\n",
      "Epoch: 2091 Loss: 0.14405903220176697\n",
      "Epoch: 2092 Loss: 0.1438966691493988\n",
      "Epoch: 2093 Loss: 0.1437695324420929\n",
      "Epoch: 2094 Loss: 0.14366938173770905\n",
      "Epoch: 2095 Loss: 0.1435183435678482\n",
      "Epoch: 2096 Loss: 0.14344394207000732\n",
      "Epoch: 2097 Loss: 0.1433134824037552\n",
      "Epoch: 2098 Loss: 0.1431715190410614\n",
      "Epoch: 2099 Loss: 0.1430383324623108\n",
      "Epoch: 2100 Loss: 0.14290861785411835\n",
      "Epoch: 2101 Loss: 0.14282990992069244\n",
      "Epoch: 2102 Loss: 0.14267833530902863\n",
      "Epoch: 2103 Loss: 0.14252060651779175\n",
      "Epoch: 2104 Loss: 0.14244794845581055\n",
      "Epoch: 2105 Loss: 0.1423359513282776\n",
      "Epoch: 2106 Loss: 0.1421981006860733\n",
      "Epoch: 2107 Loss: 0.142067551612854\n",
      "Epoch: 2108 Loss: 0.14189712703227997\n",
      "Epoch: 2109 Loss: 0.14187049865722656\n",
      "Epoch: 2110 Loss: 0.1417236179113388\n",
      "Epoch: 2111 Loss: 0.1416196972131729\n",
      "Epoch: 2112 Loss: 0.14147703349590302\n",
      "Epoch: 2113 Loss: 0.14146281778812408\n",
      "Epoch: 2114 Loss: 0.141211599111557\n",
      "Epoch: 2115 Loss: 0.14110133051872253\n",
      "Epoch: 2116 Loss: 0.14093978703022003\n",
      "Epoch: 2117 Loss: 0.1408991664648056\n",
      "Epoch: 2118 Loss: 0.14081187546253204\n",
      "Epoch: 2119 Loss: 0.1406192034482956\n",
      "Epoch: 2120 Loss: 0.1405242681503296\n",
      "Epoch: 2121 Loss: 0.1403992772102356\n",
      "Epoch: 2122 Loss: 0.14029483497142792\n",
      "Epoch: 2123 Loss: 0.14013487100601196\n",
      "Epoch: 2124 Loss: 0.1400863081216812\n",
      "Epoch: 2125 Loss: 0.13986553251743317\n",
      "Epoch: 2126 Loss: 0.13977056741714478\n",
      "Epoch: 2127 Loss: 0.13967415690422058\n",
      "Epoch: 2128 Loss: 0.13953085243701935\n",
      "Epoch: 2129 Loss: 0.13941046595573425\n",
      "Epoch: 2130 Loss: 0.13933566212654114\n",
      "Epoch: 2131 Loss: 0.139188751578331\n",
      "Epoch: 2132 Loss: 0.1390685737133026\n",
      "Epoch: 2133 Loss: 0.13904176652431488\n",
      "Epoch: 2134 Loss: 0.1388338953256607\n",
      "Epoch: 2135 Loss: 0.13873611390590668\n",
      "Epoch: 2136 Loss: 0.13860952854156494\n",
      "Epoch: 2137 Loss: 0.13848111033439636\n",
      "Epoch: 2138 Loss: 0.1383615881204605\n",
      "Epoch: 2139 Loss: 0.13821373879909515\n",
      "Epoch: 2140 Loss: 0.13813914358615875\n",
      "Epoch: 2141 Loss: 0.1380343735218048\n",
      "Epoch: 2142 Loss: 0.13792802393436432\n",
      "Epoch: 2143 Loss: 0.1378193348646164\n",
      "Epoch: 2144 Loss: 0.13764135539531708\n",
      "Epoch: 2145 Loss: 0.13753293454647064\n",
      "Epoch: 2146 Loss: 0.13742581009864807\n",
      "Epoch: 2147 Loss: 0.137301504611969\n",
      "Epoch: 2148 Loss: 0.13718347251415253\n",
      "Epoch: 2149 Loss: 0.13712668418884277\n",
      "Epoch: 2150 Loss: 0.1371253877878189\n",
      "Epoch: 2151 Loss: 0.13690297305583954\n",
      "Epoch: 2152 Loss: 0.1367322951555252\n",
      "Epoch: 2153 Loss: 0.13658805191516876\n",
      "Epoch: 2154 Loss: 0.1365414410829544\n",
      "Epoch: 2155 Loss: 0.13637393712997437\n",
      "Epoch: 2156 Loss: 0.13643954694271088\n",
      "Epoch: 2157 Loss: 0.13620522618293762\n",
      "Epoch: 2158 Loss: 0.13606718182563782\n",
      "Epoch: 2159 Loss: 0.13593080639839172\n",
      "Epoch: 2160 Loss: 0.135867178440094\n",
      "Epoch: 2161 Loss: 0.13584111630916595\n",
      "Epoch: 2162 Loss: 0.13561053574085236\n",
      "Epoch: 2163 Loss: 0.13548290729522705\n",
      "Epoch: 2164 Loss: 0.1353304386138916\n",
      "Epoch: 2165 Loss: 0.13526400923728943\n",
      "Epoch: 2166 Loss: 0.13515540957450867\n",
      "Epoch: 2167 Loss: 0.13497669994831085\n",
      "Epoch: 2168 Loss: 0.1349516212940216\n",
      "Epoch: 2169 Loss: 0.13479331135749817\n",
      "Epoch: 2170 Loss: 0.13468825817108154\n",
      "Epoch: 2171 Loss: 0.13456732034683228\n",
      "Epoch: 2172 Loss: 0.134460911154747\n",
      "Epoch: 2173 Loss: 0.1343681961297989\n",
      "Epoch: 2174 Loss: 0.13419702649116516\n",
      "Epoch: 2175 Loss: 0.1341284215450287\n",
      "Epoch: 2176 Loss: 0.1340126395225525\n",
      "Epoch: 2177 Loss: 0.13392619788646698\n",
      "Epoch: 2178 Loss: 0.13378863036632538\n",
      "Epoch: 2179 Loss: 0.13372103869915009\n",
      "Epoch: 2180 Loss: 0.13353821635246277\n",
      "Epoch: 2181 Loss: 0.13344275951385498\n",
      "Epoch: 2182 Loss: 0.13332252204418182\n",
      "Epoch: 2183 Loss: 0.13323894143104553\n",
      "Epoch: 2184 Loss: 0.13316676020622253\n",
      "Epoch: 2185 Loss: 0.13300007581710815\n",
      "Epoch: 2186 Loss: 0.1328650861978531\n",
      "Epoch: 2187 Loss: 0.13283732533454895\n",
      "Epoch: 2188 Loss: 0.13269829750061035\n",
      "Epoch: 2189 Loss: 0.13257993757724762\n",
      "Epoch: 2190 Loss: 0.13246263563632965\n",
      "Epoch: 2191 Loss: 0.1323983371257782\n",
      "Epoch: 2192 Loss: 0.13225576281547546\n",
      "Epoch: 2193 Loss: 0.13214515149593353\n",
      "Epoch: 2194 Loss: 0.1320744752883911\n",
      "Epoch: 2195 Loss: 0.13192954659461975\n",
      "Epoch: 2196 Loss: 0.13180257380008698\n",
      "Epoch: 2197 Loss: 0.13169114291667938\n",
      "Epoch: 2198 Loss: 0.13161495327949524\n",
      "Epoch: 2199 Loss: 0.13150052726268768\n",
      "Epoch: 2200 Loss: 0.1313961297273636\n",
      "Epoch: 2201 Loss: 0.13126052916049957\n",
      "Epoch: 2202 Loss: 0.1312159299850464\n",
      "Epoch: 2203 Loss: 0.13105455040931702\n",
      "Epoch: 2204 Loss: 0.13098491728305817\n",
      "Epoch: 2205 Loss: 0.13088083267211914\n",
      "Epoch: 2206 Loss: 0.1307799518108368\n",
      "Epoch: 2207 Loss: 0.130622997879982\n",
      "Epoch: 2208 Loss: 0.13050577044487\n",
      "Epoch: 2209 Loss: 0.13045255839824677\n",
      "Epoch: 2210 Loss: 0.13030222058296204\n",
      "Epoch: 2211 Loss: 0.1302012950181961\n",
      "Epoch: 2212 Loss: 0.13011261820793152\n",
      "Epoch: 2213 Loss: 0.13001905381679535\n",
      "Epoch: 2214 Loss: 0.1298697590827942\n",
      "Epoch: 2215 Loss: 0.12979529798030853\n",
      "Epoch: 2216 Loss: 0.12967617809772491\n",
      "Epoch: 2217 Loss: 0.12958401441574097\n",
      "Epoch: 2218 Loss: 0.12952761352062225\n",
      "Epoch: 2219 Loss: 0.12938450276851654\n",
      "Epoch: 2220 Loss: 0.12922394275665283\n",
      "Epoch: 2221 Loss: 0.12920431792736053\n",
      "Epoch: 2222 Loss: 0.1290217787027359\n",
      "Epoch: 2223 Loss: 0.12895691394805908\n",
      "Epoch: 2224 Loss: 0.12881144881248474\n",
      "Epoch: 2225 Loss: 0.12877732515335083\n",
      "Epoch: 2226 Loss: 0.1285945177078247\n",
      "Epoch: 2227 Loss: 0.128536194562912\n",
      "Epoch: 2228 Loss: 0.12842972576618195\n",
      "Epoch: 2229 Loss: 0.12835067510604858\n",
      "Epoch: 2230 Loss: 0.12822534143924713\n",
      "Epoch: 2231 Loss: 0.12809757888317108\n",
      "Epoch: 2232 Loss: 0.12799683213233948\n",
      "Epoch: 2233 Loss: 0.12791000306606293\n",
      "Epoch: 2234 Loss: 0.12779481709003448\n",
      "Epoch: 2235 Loss: 0.12773092091083527\n",
      "Epoch: 2236 Loss: 0.12761448323726654\n",
      "Epoch: 2237 Loss: 0.12745878100395203\n",
      "Epoch: 2238 Loss: 0.12733063101768494\n",
      "Epoch: 2239 Loss: 0.12726326286792755\n",
      "Epoch: 2240 Loss: 0.12716057896614075\n",
      "Epoch: 2241 Loss: 0.1270776242017746\n",
      "Epoch: 2242 Loss: 0.12692499160766602\n",
      "Epoch: 2243 Loss: 0.1268630474805832\n",
      "Epoch: 2244 Loss: 0.1267419457435608\n",
      "Epoch: 2245 Loss: 0.12664741277694702\n",
      "Epoch: 2246 Loss: 0.1265416443347931\n",
      "Epoch: 2247 Loss: 0.12644034624099731\n",
      "Epoch: 2248 Loss: 0.12636809051036835\n",
      "Epoch: 2249 Loss: 0.126177579164505\n",
      "Epoch: 2250 Loss: 0.12605290114879608\n",
      "Epoch: 2251 Loss: 0.12600122392177582\n",
      "Epoch: 2252 Loss: 0.125886932015419\n",
      "Epoch: 2253 Loss: 0.12579914927482605\n",
      "Epoch: 2254 Loss: 0.12571170926094055\n",
      "Epoch: 2255 Loss: 0.1256406307220459\n",
      "Epoch: 2256 Loss: 0.12550176680088043\n",
      "Epoch: 2257 Loss: 0.12539587914943695\n",
      "Epoch: 2258 Loss: 0.12529081106185913\n",
      "Epoch: 2259 Loss: 0.12517334520816803\n",
      "Epoch: 2260 Loss: 0.1251039355993271\n",
      "Epoch: 2261 Loss: 0.12501144409179688\n",
      "Epoch: 2262 Loss: 0.12491472065448761\n",
      "Epoch: 2263 Loss: 0.12474685907363892\n",
      "Epoch: 2264 Loss: 0.12470056861639023\n",
      "Epoch: 2265 Loss: 0.12460631877183914\n",
      "Epoch: 2266 Loss: 0.12445273250341415\n",
      "Epoch: 2267 Loss: 0.12439730018377304\n",
      "Epoch: 2268 Loss: 0.12431036680936813\n",
      "Epoch: 2269 Loss: 0.12417563050985336\n",
      "Epoch: 2270 Loss: 0.12409457564353943\n",
      "Epoch: 2271 Loss: 0.12400169670581818\n",
      "Epoch: 2272 Loss: 0.12388542294502258\n",
      "Epoch: 2273 Loss: 0.12376274913549423\n",
      "Epoch: 2274 Loss: 0.12368330359458923\n",
      "Epoch: 2275 Loss: 0.12359648942947388\n",
      "Epoch: 2276 Loss: 0.12347058206796646\n",
      "Epoch: 2277 Loss: 0.12338166683912277\n",
      "Epoch: 2278 Loss: 0.12331751734018326\n",
      "Epoch: 2279 Loss: 0.12321401387453079\n",
      "Epoch: 2280 Loss: 0.12310442328453064\n",
      "Epoch: 2281 Loss: 0.12298476696014404\n",
      "Epoch: 2282 Loss: 0.12295713275671005\n",
      "Epoch: 2283 Loss: 0.12279562652111053\n",
      "Epoch: 2284 Loss: 0.12272531539201736\n",
      "Epoch: 2285 Loss: 0.12259741127490997\n",
      "Epoch: 2286 Loss: 0.12248094379901886\n",
      "Epoch: 2287 Loss: 0.12240269035100937\n",
      "Epoch: 2288 Loss: 0.12229373306035995\n",
      "Epoch: 2289 Loss: 0.12225931882858276\n",
      "Epoch: 2290 Loss: 0.12210360169410706\n",
      "Epoch: 2291 Loss: 0.12201957404613495\n",
      "Epoch: 2292 Loss: 0.12198562920093536\n",
      "Epoch: 2293 Loss: 0.1219174936413765\n",
      "Epoch: 2294 Loss: 0.12171939015388489\n",
      "Epoch: 2295 Loss: 0.12159896641969681\n",
      "Epoch: 2296 Loss: 0.12152651697397232\n",
      "Epoch: 2297 Loss: 0.121456079185009\n",
      "Epoch: 2298 Loss: 0.12134332209825516\n",
      "Epoch: 2299 Loss: 0.12124910205602646\n",
      "Epoch: 2300 Loss: 0.12116345763206482\n",
      "Epoch: 2301 Loss: 0.12102154642343521\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2302 Loss: 0.12094835191965103\n",
      "Epoch: 2303 Loss: 0.12082891911268234\n",
      "Epoch: 2304 Loss: 0.12072534114122391\n",
      "Epoch: 2305 Loss: 0.12065427005290985\n",
      "Epoch: 2306 Loss: 0.12057747691869736\n",
      "Epoch: 2307 Loss: 0.12043754756450653\n",
      "Epoch: 2308 Loss: 0.12038204073905945\n",
      "Epoch: 2309 Loss: 0.12027539312839508\n",
      "Epoch: 2310 Loss: 0.1201474666595459\n",
      "Epoch: 2311 Loss: 0.12011440843343735\n",
      "Epoch: 2312 Loss: 0.12001331150531769\n",
      "Epoch: 2313 Loss: 0.11988004297018051\n",
      "Epoch: 2314 Loss: 0.11981222033500671\n",
      "Epoch: 2315 Loss: 0.11968223005533218\n",
      "Epoch: 2316 Loss: 0.11963416635990143\n",
      "Epoch: 2317 Loss: 0.11959797888994217\n",
      "Epoch: 2318 Loss: 0.1194261908531189\n",
      "Epoch: 2319 Loss: 0.11933813244104385\n",
      "Epoch: 2320 Loss: 0.11927976459264755\n",
      "Epoch: 2321 Loss: 0.11914707720279694\n",
      "Epoch: 2322 Loss: 0.11906152963638306\n",
      "Epoch: 2323 Loss: 0.11897692829370499\n",
      "Epoch: 2324 Loss: 0.11889945715665817\n",
      "Epoch: 2325 Loss: 0.11876726150512695\n",
      "Epoch: 2326 Loss: 0.11878830939531326\n",
      "Epoch: 2327 Loss: 0.11858037859201431\n",
      "Epoch: 2328 Loss: 0.1184835433959961\n",
      "Epoch: 2329 Loss: 0.11838492751121521\n",
      "Epoch: 2330 Loss: 0.11830653250217438\n",
      "Epoch: 2331 Loss: 0.11821296066045761\n",
      "Epoch: 2332 Loss: 0.11811916530132294\n",
      "Epoch: 2333 Loss: 0.11802039295434952\n",
      "Epoch: 2334 Loss: 0.11795660108327866\n",
      "Epoch: 2335 Loss: 0.11786749958992004\n",
      "Epoch: 2336 Loss: 0.11775083839893341\n",
      "Epoch: 2337 Loss: 0.11768194288015366\n",
      "Epoch: 2338 Loss: 0.1175265908241272\n",
      "Epoch: 2339 Loss: 0.11747796088457108\n",
      "Epoch: 2340 Loss: 0.11740333586931229\n",
      "Epoch: 2341 Loss: 0.11730468273162842\n",
      "Epoch: 2342 Loss: 0.11718983203172684\n",
      "Epoch: 2343 Loss: 0.11710739880800247\n",
      "Epoch: 2344 Loss: 0.11698529869318008\n",
      "Epoch: 2345 Loss: 0.11691338568925858\n",
      "Epoch: 2346 Loss: 0.11683853715658188\n",
      "Epoch: 2347 Loss: 0.1167425587773323\n",
      "Epoch: 2348 Loss: 0.116627037525177\n",
      "Epoch: 2349 Loss: 0.11657006293535233\n",
      "Epoch: 2350 Loss: 0.11644099652767181\n",
      "Epoch: 2351 Loss: 0.11636840552091599\n",
      "Epoch: 2352 Loss: 0.1162811890244484\n",
      "Epoch: 2353 Loss: 0.11622931063175201\n",
      "Epoch: 2354 Loss: 0.1160440668463707\n",
      "Epoch: 2355 Loss: 0.11602728813886642\n",
      "Epoch: 2356 Loss: 0.11583486199378967\n",
      "Epoch: 2357 Loss: 0.11581449210643768\n",
      "Epoch: 2358 Loss: 0.11568663269281387\n",
      "Epoch: 2359 Loss: 0.11566939204931259\n",
      "Epoch: 2360 Loss: 0.11555426567792892\n",
      "Epoch: 2361 Loss: 0.11547265201807022\n",
      "Epoch: 2362 Loss: 0.11537143588066101\n",
      "Epoch: 2363 Loss: 0.11528084427118301\n",
      "Epoch: 2364 Loss: 0.11519507318735123\n",
      "Epoch: 2365 Loss: 0.11511296033859253\n",
      "Epoch: 2366 Loss: 0.11506059020757675\n",
      "Epoch: 2367 Loss: 0.11491037160158157\n",
      "Epoch: 2368 Loss: 0.11487164348363876\n",
      "Epoch: 2369 Loss: 0.11475256085395813\n",
      "Epoch: 2370 Loss: 0.1146913543343544\n",
      "Epoch: 2371 Loss: 0.11457786709070206\n",
      "Epoch: 2372 Loss: 0.11444619297981262\n",
      "Epoch: 2373 Loss: 0.1144079715013504\n",
      "Epoch: 2374 Loss: 0.11427902430295944\n",
      "Epoch: 2375 Loss: 0.11423441767692566\n",
      "Epoch: 2376 Loss: 0.11430070549249649\n",
      "Epoch: 2377 Loss: 0.11403109133243561\n",
      "Epoch: 2378 Loss: 0.11394429951906204\n",
      "Epoch: 2379 Loss: 0.11383221298456192\n",
      "Epoch: 2380 Loss: 0.11375152319669724\n",
      "Epoch: 2381 Loss: 0.1136314794421196\n",
      "Epoch: 2382 Loss: 0.11355825513601303\n",
      "Epoch: 2383 Loss: 0.11352688074111938\n",
      "Epoch: 2384 Loss: 0.11342532932758331\n",
      "Epoch: 2385 Loss: 0.11334222555160522\n",
      "Epoch: 2386 Loss: 0.1132320761680603\n",
      "Epoch: 2387 Loss: 0.11309998482465744\n",
      "Epoch: 2388 Loss: 0.1130642369389534\n",
      "Epoch: 2389 Loss: 0.11298112571239471\n",
      "Epoch: 2390 Loss: 0.11292729526758194\n",
      "Epoch: 2391 Loss: 0.11283034831285477\n",
      "Epoch: 2392 Loss: 0.1126723513007164\n",
      "Epoch: 2393 Loss: 0.11257849633693695\n",
      "Epoch: 2394 Loss: 0.11264524608850479\n",
      "Epoch: 2395 Loss: 0.11243174225091934\n",
      "Epoch: 2396 Loss: 0.11234928667545319\n",
      "Epoch: 2397 Loss: 0.11224684864282608\n",
      "Epoch: 2398 Loss: 0.11212557554244995\n",
      "Epoch: 2399 Loss: 0.1120779886841774\n",
      "Epoch: 2400 Loss: 0.11201925575733185\n",
      "Epoch: 2401 Loss: 0.1119307428598404\n",
      "Epoch: 2402 Loss: 0.11184912174940109\n",
      "Epoch: 2403 Loss: 0.11175782233476639\n",
      "Epoch: 2404 Loss: 0.11170151084661484\n",
      "Epoch: 2405 Loss: 0.11157654225826263\n",
      "Epoch: 2406 Loss: 0.11149241775274277\n",
      "Epoch: 2407 Loss: 0.11141468584537506\n",
      "Epoch: 2408 Loss: 0.1113765612244606\n",
      "Epoch: 2409 Loss: 0.11123597621917725\n",
      "Epoch: 2410 Loss: 0.11111535876989365\n",
      "Epoch: 2411 Loss: 0.11104097217321396\n",
      "Epoch: 2412 Loss: 0.1109827309846878\n",
      "Epoch: 2413 Loss: 0.1109137237071991\n",
      "Epoch: 2414 Loss: 0.1107732504606247\n",
      "Epoch: 2415 Loss: 0.11068378388881683\n",
      "Epoch: 2416 Loss: 0.110643669962883\n",
      "Epoch: 2417 Loss: 0.11055096983909607\n",
      "Epoch: 2418 Loss: 0.11053670942783356\n",
      "Epoch: 2419 Loss: 0.11036039143800735\n",
      "Epoch: 2420 Loss: 0.11029335111379623\n",
      "Epoch: 2421 Loss: 0.11021919548511505\n",
      "Epoch: 2422 Loss: 0.11013353615999222\n",
      "Epoch: 2423 Loss: 0.11003578454256058\n",
      "Epoch: 2424 Loss: 0.11002924293279648\n",
      "Epoch: 2425 Loss: 0.10987478494644165\n",
      "Epoch: 2426 Loss: 0.10976026207208633\n",
      "Epoch: 2427 Loss: 0.10973052680492401\n",
      "Epoch: 2428 Loss: 0.1095840334892273\n",
      "Epoch: 2429 Loss: 0.10948548465967178\n",
      "Epoch: 2430 Loss: 0.10943823307752609\n",
      "Epoch: 2431 Loss: 0.10933931916952133\n",
      "Epoch: 2432 Loss: 0.10928035527467728\n",
      "Epoch: 2433 Loss: 0.109211266040802\n",
      "Epoch: 2434 Loss: 0.10912669450044632\n",
      "Epoch: 2435 Loss: 0.10900478065013885\n",
      "Epoch: 2436 Loss: 0.10894936323165894\n",
      "Epoch: 2437 Loss: 0.10881581157445908\n",
      "Epoch: 2438 Loss: 0.1088620200753212\n",
      "Epoch: 2439 Loss: 0.10870081186294556\n",
      "Epoch: 2440 Loss: 0.10858982056379318\n",
      "Epoch: 2441 Loss: 0.1085437536239624\n",
      "Epoch: 2442 Loss: 0.10850392282009125\n",
      "Epoch: 2443 Loss: 0.10837771743535995\n",
      "Epoch: 2444 Loss: 0.10825274139642715\n",
      "Epoch: 2445 Loss: 0.10819215327501297\n",
      "Epoch: 2446 Loss: 0.10809962451457977\n",
      "Epoch: 2447 Loss: 0.10808432847261429\n",
      "Epoch: 2448 Loss: 0.10795247554779053\n",
      "Epoch: 2449 Loss: 0.10786028951406479\n",
      "Epoch: 2450 Loss: 0.10775446146726608\n",
      "Epoch: 2451 Loss: 0.10769655555486679\n",
      "Epoch: 2452 Loss: 0.10759402811527252\n",
      "Epoch: 2453 Loss: 0.10753600299358368\n",
      "Epoch: 2454 Loss: 0.10753369331359863\n",
      "Epoch: 2455 Loss: 0.10736706107854843\n",
      "Epoch: 2456 Loss: 0.10729486495256424\n",
      "Epoch: 2457 Loss: 0.10716968029737473\n",
      "Epoch: 2458 Loss: 0.10710307955741882\n",
      "Epoch: 2459 Loss: 0.10699579119682312\n",
      "Epoch: 2460 Loss: 0.10700186342000961\n",
      "Epoch: 2461 Loss: 0.10688138753175735\n",
      "Epoch: 2462 Loss: 0.10673818737268448\n",
      "Epoch: 2463 Loss: 0.10671260952949524\n",
      "Epoch: 2464 Loss: 0.10660362988710403\n",
      "Epoch: 2465 Loss: 0.10662295669317245\n",
      "Epoch: 2466 Loss: 0.1064903512597084\n",
      "Epoch: 2467 Loss: 0.1064644381403923\n",
      "Epoch: 2468 Loss: 0.10628676414489746\n",
      "Epoch: 2469 Loss: 0.10620314627885818\n",
      "Epoch: 2470 Loss: 0.10613609105348587\n",
      "Epoch: 2471 Loss: 0.10605161637067795\n",
      "Epoch: 2472 Loss: 0.10597309470176697\n",
      "Epoch: 2473 Loss: 0.10587145388126373\n",
      "Epoch: 2474 Loss: 0.10579385608434677\n",
      "Epoch: 2475 Loss: 0.10578300058841705\n",
      "Epoch: 2476 Loss: 0.10567069798707962\n",
      "Epoch: 2477 Loss: 0.10558231174945831\n",
      "Epoch: 2478 Loss: 0.1054796576499939\n",
      "Epoch: 2479 Loss: 0.10539707541465759\n",
      "Epoch: 2480 Loss: 0.10540644824504852\n",
      "Epoch: 2481 Loss: 0.10524433106184006\n",
      "Epoch: 2482 Loss: 0.10517995059490204\n",
      "Epoch: 2483 Loss: 0.10510674118995667\n",
      "Epoch: 2484 Loss: 0.10498447716236115\n",
      "Epoch: 2485 Loss: 0.10493656247854233\n",
      "Epoch: 2486 Loss: 0.104831762611866\n",
      "Epoch: 2487 Loss: 0.10472278296947479\n",
      "Epoch: 2488 Loss: 0.10471753031015396\n",
      "Epoch: 2489 Loss: 0.10464334487915039\n",
      "Epoch: 2490 Loss: 0.1044948548078537\n",
      "Epoch: 2491 Loss: 0.10442136228084564\n",
      "Epoch: 2492 Loss: 0.1043730080127716\n",
      "Epoch: 2493 Loss: 0.10427508503198624\n",
      "Epoch: 2494 Loss: 0.10430574417114258\n",
      "Epoch: 2495 Loss: 0.10412009805440903\n",
      "Epoch: 2496 Loss: 0.10411036759614944\n",
      "Epoch: 2497 Loss: 0.10395914316177368\n",
      "Epoch: 2498 Loss: 0.10392135381698608\n",
      "Epoch: 2499 Loss: 0.10379348695278168\n",
      "Epoch: 2500 Loss: 0.10373339802026749\n",
      "Epoch: 2501 Loss: 0.1036612018942833\n",
      "Epoch: 2502 Loss: 0.10357943177223206\n",
      "Epoch: 2503 Loss: 0.10350151360034943\n",
      "Epoch: 2504 Loss: 0.10347075760364532\n",
      "Epoch: 2505 Loss: 0.1033489778637886\n",
      "Epoch: 2506 Loss: 0.10325870662927628\n",
      "Epoch: 2507 Loss: 0.10315737873315811\n",
      "Epoch: 2508 Loss: 0.1030658408999443\n",
      "Epoch: 2509 Loss: 0.10296737402677536\n",
      "Epoch: 2510 Loss: 0.10294201225042343\n",
      "Epoch: 2511 Loss: 0.10284726321697235\n",
      "Epoch: 2512 Loss: 0.10284361243247986\n",
      "Epoch: 2513 Loss: 0.10269278287887573\n",
      "Epoch: 2514 Loss: 0.1026691272854805\n",
      "Epoch: 2515 Loss: 0.10250720381736755\n",
      "Epoch: 2516 Loss: 0.10241011530160904\n",
      "Epoch: 2517 Loss: 0.10238327085971832\n",
      "Epoch: 2518 Loss: 0.10228448361158371\n",
      "Epoch: 2519 Loss: 0.10223178565502167\n",
      "Epoch: 2520 Loss: 0.10220381617546082\n",
      "Epoch: 2521 Loss: 0.10208015143871307\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2522 Loss: 0.10196613520383835\n",
      "Epoch: 2523 Loss: 0.10189534723758698\n",
      "Epoch: 2524 Loss: 0.10180655121803284\n",
      "Epoch: 2525 Loss: 0.10179479420185089\n",
      "Epoch: 2526 Loss: 0.10169700533151627\n",
      "Epoch: 2527 Loss: 0.10156562179327011\n",
      "Epoch: 2528 Loss: 0.10150012373924255\n",
      "Epoch: 2529 Loss: 0.1015271246433258\n",
      "Epoch: 2530 Loss: 0.10137586295604706\n",
      "Epoch: 2531 Loss: 0.10125355422496796\n",
      "Epoch: 2532 Loss: 0.10121708363294601\n",
      "Epoch: 2533 Loss: 0.1011032685637474\n",
      "Epoch: 2534 Loss: 0.101102314889431\n",
      "Epoch: 2535 Loss: 0.1009720116853714\n",
      "Epoch: 2536 Loss: 0.10090232640504837\n",
      "Epoch: 2537 Loss: 0.10086604952812195\n",
      "Epoch: 2538 Loss: 0.10076442360877991\n",
      "Epoch: 2539 Loss: 0.10071643441915512\n",
      "Epoch: 2540 Loss: 0.10058549046516418\n",
      "Epoch: 2541 Loss: 0.10049436241388321\n",
      "Epoch: 2542 Loss: 0.10044986009597778\n",
      "Epoch: 2543 Loss: 0.10034430027008057\n",
      "Epoch: 2544 Loss: 0.10030896216630936\n",
      "Epoch: 2545 Loss: 0.10021300613880157\n",
      "Epoch: 2546 Loss: 0.10015030950307846\n",
      "Epoch: 2547 Loss: 0.10009554773569107\n",
      "Epoch: 2548 Loss: 0.10002364218235016\n",
      "Epoch: 2549 Loss: 0.09990083426237106\n",
      "Epoch: 2550 Loss: 0.09981104731559753\n",
      "Epoch: 2551 Loss: 0.09976886212825775\n",
      "Epoch: 2552 Loss: 0.09968521445989609\n",
      "Epoch: 2553 Loss: 0.09957646578550339\n",
      "Epoch: 2554 Loss: 0.09954901039600372\n",
      "Epoch: 2555 Loss: 0.0994100570678711\n",
      "Epoch: 2556 Loss: 0.09936922043561935\n",
      "Epoch: 2557 Loss: 0.09926565736532211\n",
      "Epoch: 2558 Loss: 0.09927382320165634\n",
      "Epoch: 2559 Loss: 0.09911173582077026\n",
      "Epoch: 2560 Loss: 0.09899633377790451\n",
      "Epoch: 2561 Loss: 0.09897012263536453\n",
      "Epoch: 2562 Loss: 0.09890072047710419\n",
      "Epoch: 2563 Loss: 0.0987851545214653\n",
      "Epoch: 2564 Loss: 0.09873649477958679\n",
      "Epoch: 2565 Loss: 0.09859637916088104\n",
      "Epoch: 2566 Loss: 0.09864160418510437\n",
      "Epoch: 2567 Loss: 0.09852640330791473\n",
      "Epoch: 2568 Loss: 0.09839451313018799\n",
      "Epoch: 2569 Loss: 0.09832648932933807\n",
      "Epoch: 2570 Loss: 0.09823014587163925\n",
      "Epoch: 2571 Loss: 0.09823810309171677\n",
      "Epoch: 2572 Loss: 0.09808667004108429\n",
      "Epoch: 2573 Loss: 0.09803294390439987\n",
      "Epoch: 2574 Loss: 0.09792768210172653\n",
      "Epoch: 2575 Loss: 0.09789326786994934\n",
      "Epoch: 2576 Loss: 0.09775558859109879\n",
      "Epoch: 2577 Loss: 0.09772388637065887\n",
      "Epoch: 2578 Loss: 0.09762263298034668\n",
      "Epoch: 2579 Loss: 0.09755019098520279\n",
      "Epoch: 2580 Loss: 0.09750629961490631\n",
      "Epoch: 2581 Loss: 0.09741496294736862\n",
      "Epoch: 2582 Loss: 0.09730024635791779\n",
      "Epoch: 2583 Loss: 0.09730259329080582\n",
      "Epoch: 2584 Loss: 0.09718310832977295\n",
      "Epoch: 2585 Loss: 0.09709259867668152\n",
      "Epoch: 2586 Loss: 0.09702540934085846\n",
      "Epoch: 2587 Loss: 0.09699900448322296\n",
      "Epoch: 2588 Loss: 0.09685094654560089\n",
      "Epoch: 2589 Loss: 0.09682624042034149\n",
      "Epoch: 2590 Loss: 0.09672528505325317\n",
      "Epoch: 2591 Loss: 0.09664794057607651\n",
      "Epoch: 2592 Loss: 0.09654281288385391\n",
      "Epoch: 2593 Loss: 0.09648877382278442\n",
      "Epoch: 2594 Loss: 0.09643738716840744\n",
      "Epoch: 2595 Loss: 0.09638766199350357\n",
      "Epoch: 2596 Loss: 0.09630274772644043\n",
      "Epoch: 2597 Loss: 0.09621720761060715\n",
      "Epoch: 2598 Loss: 0.09614186733961105\n",
      "Epoch: 2599 Loss: 0.09605761617422104\n",
      "Epoch: 2600 Loss: 0.09601820260286331\n",
      "Epoch: 2601 Loss: 0.09591639041900635\n",
      "Epoch: 2602 Loss: 0.09580153971910477\n",
      "Epoch: 2603 Loss: 0.09577988088130951\n",
      "Epoch: 2604 Loss: 0.09566616266965866\n",
      "Epoch: 2605 Loss: 0.0956633910536766\n",
      "Epoch: 2606 Loss: 0.09553852677345276\n",
      "Epoch: 2607 Loss: 0.09542074054479599\n",
      "Epoch: 2608 Loss: 0.09535495191812515\n",
      "Epoch: 2609 Loss: 0.0952860489487648\n",
      "Epoch: 2610 Loss: 0.09520960599184036\n",
      "Epoch: 2611 Loss: 0.09521769732236862\n",
      "Epoch: 2612 Loss: 0.09508805721998215\n",
      "Epoch: 2613 Loss: 0.0950460433959961\n",
      "Epoch: 2614 Loss: 0.09490367025136948\n",
      "Epoch: 2615 Loss: 0.09490740299224854\n",
      "Epoch: 2616 Loss: 0.09478479623794556\n",
      "Epoch: 2617 Loss: 0.09476125985383987\n",
      "Epoch: 2618 Loss: 0.09466181695461273\n",
      "Epoch: 2619 Loss: 0.09454605728387833\n",
      "Epoch: 2620 Loss: 0.09448433667421341\n",
      "Epoch: 2621 Loss: 0.09443330019712448\n",
      "Epoch: 2622 Loss: 0.09437830001115799\n",
      "Epoch: 2623 Loss: 0.09424570947885513\n",
      "Epoch: 2624 Loss: 0.09424177557229996\n",
      "Epoch: 2625 Loss: 0.09411130845546722\n",
      "Epoch: 2626 Loss: 0.09407918900251389\n",
      "Epoch: 2627 Loss: 0.09400143474340439\n",
      "Epoch: 2628 Loss: 0.09390910714864731\n",
      "Epoch: 2629 Loss: 0.09381529688835144\n",
      "Epoch: 2630 Loss: 0.09376692771911621\n",
      "Epoch: 2631 Loss: 0.09367995709180832\n",
      "Epoch: 2632 Loss: 0.09365979582071304\n",
      "Epoch: 2633 Loss: 0.09354851394891739\n",
      "Epoch: 2634 Loss: 0.0935235470533371\n",
      "Epoch: 2635 Loss: 0.09336230903863907\n",
      "Epoch: 2636 Loss: 0.0932798907160759\n",
      "Epoch: 2637 Loss: 0.09330150485038757\n",
      "Epoch: 2638 Loss: 0.09317928552627563\n",
      "Epoch: 2639 Loss: 0.09310459345579147\n",
      "Epoch: 2640 Loss: 0.09306685626506805\n",
      "Epoch: 2641 Loss: 0.09294984489679337\n",
      "Epoch: 2642 Loss: 0.09288332611322403\n",
      "Epoch: 2643 Loss: 0.09280785918235779\n",
      "Epoch: 2644 Loss: 0.09269116073846817\n",
      "Epoch: 2645 Loss: 0.09268762916326523\n",
      "Epoch: 2646 Loss: 0.092619389295578\n",
      "Epoch: 2647 Loss: 0.09252262115478516\n",
      "Epoch: 2648 Loss: 0.09245459735393524\n",
      "Epoch: 2649 Loss: 0.09239370375871658\n",
      "Epoch: 2650 Loss: 0.09227310121059418\n",
      "Epoch: 2651 Loss: 0.09223093837499619\n",
      "Epoch: 2652 Loss: 0.09222734719514847\n",
      "Epoch: 2653 Loss: 0.0920739471912384\n",
      "Epoch: 2654 Loss: 0.0920117124915123\n",
      "Epoch: 2655 Loss: 0.09193158149719238\n",
      "Epoch: 2656 Loss: 0.09187842905521393\n",
      "Epoch: 2657 Loss: 0.09179280698299408\n",
      "Epoch: 2658 Loss: 0.09175869822502136\n",
      "Epoch: 2659 Loss: 0.09168216586112976\n",
      "Epoch: 2660 Loss: 0.09158626943826675\n",
      "Epoch: 2661 Loss: 0.09150576591491699\n",
      "Epoch: 2662 Loss: 0.0914747565984726\n",
      "Epoch: 2663 Loss: 0.09136438369750977\n",
      "Epoch: 2664 Loss: 0.09129404276609421\n",
      "Epoch: 2665 Loss: 0.09124881774187088\n",
      "Epoch: 2666 Loss: 0.09116216748952866\n",
      "Epoch: 2667 Loss: 0.09111800044775009\n",
      "Epoch: 2668 Loss: 0.09103234857320786\n",
      "Epoch: 2669 Loss: 0.09095191210508347\n",
      "Epoch: 2670 Loss: 0.09090518951416016\n",
      "Epoch: 2671 Loss: 0.09084238111972809\n",
      "Epoch: 2672 Loss: 0.0907425582408905\n",
      "Epoch: 2673 Loss: 0.0906585305929184\n",
      "Epoch: 2674 Loss: 0.0906326025724411\n",
      "Epoch: 2675 Loss: 0.09052775800228119\n",
      "Epoch: 2676 Loss: 0.09047404676675797\n",
      "Epoch: 2677 Loss: 0.09041719138622284\n",
      "Epoch: 2678 Loss: 0.09033610671758652\n",
      "Epoch: 2679 Loss: 0.09027457982301712\n",
      "Epoch: 2680 Loss: 0.09017892181873322\n",
      "Epoch: 2681 Loss: 0.09009307622909546\n",
      "Epoch: 2682 Loss: 0.09005400538444519\n",
      "Epoch: 2683 Loss: 0.08998879790306091\n",
      "Epoch: 2684 Loss: 0.08994976431131363\n",
      "Epoch: 2685 Loss: 0.08985770493745804\n",
      "Epoch: 2686 Loss: 0.08980925381183624\n",
      "Epoch: 2687 Loss: 0.08972389250993729\n",
      "Epoch: 2688 Loss: 0.08966321498155594\n",
      "Epoch: 2689 Loss: 0.08957517892122269\n",
      "Epoch: 2690 Loss: 0.08951938152313232\n",
      "Epoch: 2691 Loss: 0.08951479941606522\n",
      "Epoch: 2692 Loss: 0.08937595784664154\n",
      "Epoch: 2693 Loss: 0.08937273919582367\n",
      "Epoch: 2694 Loss: 0.08932742476463318\n",
      "Epoch: 2695 Loss: 0.08919651806354523\n",
      "Epoch: 2696 Loss: 0.0891319215297699\n",
      "Epoch: 2697 Loss: 0.0890842005610466\n",
      "Epoch: 2698 Loss: 0.08898605406284332\n",
      "Epoch: 2699 Loss: 0.08895421773195267\n",
      "Epoch: 2700 Loss: 0.08886557072401047\n",
      "Epoch: 2701 Loss: 0.08878769725561142\n",
      "Epoch: 2702 Loss: 0.08874460309743881\n",
      "Epoch: 2703 Loss: 0.0886792317032814\n",
      "Epoch: 2704 Loss: 0.0885637179017067\n",
      "Epoch: 2705 Loss: 0.08855308592319489\n",
      "Epoch: 2706 Loss: 0.088496133685112\n",
      "Epoch: 2707 Loss: 0.08842317759990692\n",
      "Epoch: 2708 Loss: 0.08835411071777344\n",
      "Epoch: 2709 Loss: 0.08826751261949539\n",
      "Epoch: 2710 Loss: 0.0882502943277359\n",
      "Epoch: 2711 Loss: 0.08813762664794922\n",
      "Epoch: 2712 Loss: 0.08808102458715439\n",
      "Epoch: 2713 Loss: 0.08803485333919525\n",
      "Epoch: 2714 Loss: 0.08797156065702438\n",
      "Epoch: 2715 Loss: 0.08788145333528519\n",
      "Epoch: 2716 Loss: 0.08781956136226654\n",
      "Epoch: 2717 Loss: 0.0877450481057167\n",
      "Epoch: 2718 Loss: 0.08769802749156952\n",
      "Epoch: 2719 Loss: 0.08760186284780502\n",
      "Epoch: 2720 Loss: 0.08756006509065628\n",
      "Epoch: 2721 Loss: 0.08750031888484955\n",
      "Epoch: 2722 Loss: 0.08740253746509552\n",
      "Epoch: 2723 Loss: 0.08737606555223465\n",
      "Epoch: 2724 Loss: 0.0873052105307579\n",
      "Epoch: 2725 Loss: 0.08724580705165863\n",
      "Epoch: 2726 Loss: 0.08716344088315964\n",
      "Epoch: 2727 Loss: 0.08714722096920013\n",
      "Epoch: 2728 Loss: 0.08704249560832977\n",
      "Epoch: 2729 Loss: 0.08696679025888443\n",
      "Epoch: 2730 Loss: 0.08691763877868652\n",
      "Epoch: 2731 Loss: 0.08685334771871567\n",
      "Epoch: 2732 Loss: 0.0867830142378807\n",
      "Epoch: 2733 Loss: 0.08671440929174423\n",
      "Epoch: 2734 Loss: 0.08667667210102081\n",
      "Epoch: 2735 Loss: 0.08659526705741882\n",
      "Epoch: 2736 Loss: 0.08652541786432266\n",
      "Epoch: 2737 Loss: 0.0864667296409607\n",
      "Epoch: 2738 Loss: 0.08638471364974976\n",
      "Epoch: 2739 Loss: 0.08631958067417145\n",
      "Epoch: 2740 Loss: 0.08629374206066132\n",
      "Epoch: 2741 Loss: 0.08621974289417267\n",
      "Epoch: 2742 Loss: 0.08615133166313171\n",
      "Epoch: 2743 Loss: 0.08613714575767517\n",
      "Epoch: 2744 Loss: 0.0860060304403305\n",
      "Epoch: 2745 Loss: 0.08594649285078049\n",
      "Epoch: 2746 Loss: 0.08588988333940506\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2747 Loss: 0.08582881093025208\n",
      "Epoch: 2748 Loss: 0.0857662782073021\n",
      "Epoch: 2749 Loss: 0.08567773550748825\n",
      "Epoch: 2750 Loss: 0.08563879132270813\n",
      "Epoch: 2751 Loss: 0.08557826280593872\n",
      "Epoch: 2752 Loss: 0.08553162217140198\n",
      "Epoch: 2753 Loss: 0.08547022938728333\n",
      "Epoch: 2754 Loss: 0.08538629859685898\n",
      "Epoch: 2755 Loss: 0.08532965928316116\n",
      "Epoch: 2756 Loss: 0.08527617156505585\n",
      "Epoch: 2757 Loss: 0.08518342673778534\n",
      "Epoch: 2758 Loss: 0.0851394534111023\n",
      "Epoch: 2759 Loss: 0.08507613837718964\n",
      "Epoch: 2760 Loss: 0.08500506728887558\n",
      "Epoch: 2761 Loss: 0.08496491611003876\n",
      "Epoch: 2762 Loss: 0.08490445464849472\n",
      "Epoch: 2763 Loss: 0.08489058166742325\n",
      "Epoch: 2764 Loss: 0.08478289842605591\n",
      "Epoch: 2765 Loss: 0.08472540974617004\n",
      "Epoch: 2766 Loss: 0.08464295417070389\n",
      "Epoch: 2767 Loss: 0.08458083868026733\n",
      "Epoch: 2768 Loss: 0.08453255146741867\n",
      "Epoch: 2769 Loss: 0.08445384353399277\n",
      "Epoch: 2770 Loss: 0.08441763371229172\n",
      "Epoch: 2771 Loss: 0.08430495113134384\n",
      "Epoch: 2772 Loss: 0.0842999517917633\n",
      "Epoch: 2773 Loss: 0.08425932377576828\n",
      "Epoch: 2774 Loss: 0.08415248990058899\n",
      "Epoch: 2775 Loss: 0.08410238474607468\n",
      "Epoch: 2776 Loss: 0.08403433114290237\n",
      "Epoch: 2777 Loss: 0.08398014307022095\n",
      "Epoch: 2778 Loss: 0.08390144258737564\n",
      "Epoch: 2779 Loss: 0.08385223895311356\n",
      "Epoch: 2780 Loss: 0.08378682285547256\n",
      "Epoch: 2781 Loss: 0.08373253792524338\n",
      "Epoch: 2782 Loss: 0.08367837965488434\n",
      "Epoch: 2783 Loss: 0.08365923911333084\n",
      "Epoch: 2784 Loss: 0.0836017057299614\n",
      "Epoch: 2785 Loss: 0.08348522335290909\n",
      "Epoch: 2786 Loss: 0.08342870324850082\n",
      "Epoch: 2787 Loss: 0.08335689455270767\n",
      "Epoch: 2788 Loss: 0.0833105742931366\n",
      "Epoch: 2789 Loss: 0.08322900533676147\n",
      "Epoch: 2790 Loss: 0.0831911563873291\n",
      "Epoch: 2791 Loss: 0.08312805742025375\n",
      "Epoch: 2792 Loss: 0.08308136463165283\n",
      "Epoch: 2793 Loss: 0.0830119401216507\n",
      "Epoch: 2794 Loss: 0.08295377343893051\n",
      "Epoch: 2795 Loss: 0.08290591090917587\n",
      "Epoch: 2796 Loss: 0.08284147083759308\n",
      "Epoch: 2797 Loss: 0.08275553584098816\n",
      "Epoch: 2798 Loss: 0.0827193632721901\n",
      "Epoch: 2799 Loss: 0.08269799500703812\n",
      "Epoch: 2800 Loss: 0.08256153762340546\n",
      "Epoch: 2801 Loss: 0.08255361765623093\n",
      "Epoch: 2802 Loss: 0.08246765285730362\n",
      "Epoch: 2803 Loss: 0.08242132514715195\n",
      "Epoch: 2804 Loss: 0.08237304538488388\n",
      "Epoch: 2805 Loss: 0.08230483531951904\n",
      "Epoch: 2806 Loss: 0.08224085718393326\n",
      "Epoch: 2807 Loss: 0.08220035582780838\n",
      "Epoch: 2808 Loss: 0.08210522681474686\n",
      "Epoch: 2809 Loss: 0.08207271248102188\n",
      "Epoch: 2810 Loss: 0.08199841529130936\n",
      "Epoch: 2811 Loss: 0.0819372683763504\n",
      "Epoch: 2812 Loss: 0.08190670609474182\n",
      "Epoch: 2813 Loss: 0.08183993399143219\n",
      "Epoch: 2814 Loss: 0.08184049278497696\n",
      "Epoch: 2815 Loss: 0.08169165253639221\n",
      "Epoch: 2816 Loss: 0.08167808502912521\n",
      "Epoch: 2817 Loss: 0.08158671110868454\n",
      "Epoch: 2818 Loss: 0.08156252652406693\n",
      "Epoch: 2819 Loss: 0.0814492478966713\n",
      "Epoch: 2820 Loss: 0.08146311342716217\n",
      "Epoch: 2821 Loss: 0.08134648948907852\n",
      "Epoch: 2822 Loss: 0.08129681646823883\n",
      "Epoch: 2823 Loss: 0.08124912530183792\n",
      "Epoch: 2824 Loss: 0.08120150864124298\n",
      "Epoch: 2825 Loss: 0.08115743100643158\n",
      "Epoch: 2826 Loss: 0.08108273148536682\n",
      "Epoch: 2827 Loss: 0.08108080178499222\n",
      "Epoch: 2828 Loss: 0.08095362037420273\n",
      "Epoch: 2829 Loss: 0.08091304451227188\n",
      "Epoch: 2830 Loss: 0.08084013313055038\n",
      "Epoch: 2831 Loss: 0.08080300688743591\n",
      "Epoch: 2832 Loss: 0.08073357492685318\n",
      "Epoch: 2833 Loss: 0.08069588989019394\n",
      "Epoch: 2834 Loss: 0.0806451365351677\n",
      "Epoch: 2835 Loss: 0.08057153224945068\n",
      "Epoch: 2836 Loss: 0.08054724335670471\n",
      "Epoch: 2837 Loss: 0.08044209331274033\n",
      "Epoch: 2838 Loss: 0.08041226118803024\n",
      "Epoch: 2839 Loss: 0.08032435178756714\n",
      "Epoch: 2840 Loss: 0.08029303699731827\n",
      "Epoch: 2841 Loss: 0.0802050456404686\n",
      "Epoch: 2842 Loss: 0.08019909262657166\n",
      "Epoch: 2843 Loss: 0.08011588454246521\n",
      "Epoch: 2844 Loss: 0.08008594810962677\n",
      "Epoch: 2845 Loss: 0.08001071214675903\n",
      "Epoch: 2846 Loss: 0.0799422562122345\n",
      "Epoch: 2847 Loss: 0.0799231305718422\n",
      "Epoch: 2848 Loss: 0.07981148362159729\n",
      "Epoch: 2849 Loss: 0.07979932427406311\n",
      "Epoch: 2850 Loss: 0.07973048835992813\n",
      "Epoch: 2851 Loss: 0.07969421148300171\n",
      "Epoch: 2852 Loss: 0.07961583882570267\n",
      "Epoch: 2853 Loss: 0.07957509905099869\n",
      "Epoch: 2854 Loss: 0.07948486506938934\n",
      "Epoch: 2855 Loss: 0.07944782078266144\n",
      "Epoch: 2856 Loss: 0.07939133048057556\n",
      "Epoch: 2857 Loss: 0.07932573556900024\n",
      "Epoch: 2858 Loss: 0.07928582280874252\n",
      "Epoch: 2859 Loss: 0.07920096069574356\n",
      "Epoch: 2860 Loss: 0.07918775081634521\n",
      "Epoch: 2861 Loss: 0.07907263934612274\n",
      "Epoch: 2862 Loss: 0.07907823473215103\n",
      "Epoch: 2863 Loss: 0.0790310874581337\n",
      "Epoch: 2864 Loss: 0.07897845655679703\n",
      "Epoch: 2865 Loss: 0.07889565080404282\n",
      "Epoch: 2866 Loss: 0.07882381975650787\n",
      "Epoch: 2867 Loss: 0.07879079878330231\n",
      "Epoch: 2868 Loss: 0.07870921492576599\n",
      "Epoch: 2869 Loss: 0.07866654545068741\n",
      "Epoch: 2870 Loss: 0.0786060094833374\n",
      "Epoch: 2871 Loss: 0.07858020812273026\n",
      "Epoch: 2872 Loss: 0.07847744971513748\n",
      "Epoch: 2873 Loss: 0.07844308018684387\n",
      "Epoch: 2874 Loss: 0.07842184603214264\n",
      "Epoch: 2875 Loss: 0.07832787930965424\n",
      "Epoch: 2876 Loss: 0.07834059000015259\n",
      "Epoch: 2877 Loss: 0.07824618369340897\n",
      "Epoch: 2878 Loss: 0.07815553992986679\n",
      "Epoch: 2879 Loss: 0.07814650982618332\n",
      "Epoch: 2880 Loss: 0.07805564254522324\n",
      "Epoch: 2881 Loss: 0.07804834097623825\n",
      "Epoch: 2882 Loss: 0.07794799655675888\n",
      "Epoch: 2883 Loss: 0.07790012657642365\n",
      "Epoch: 2884 Loss: 0.07788760215044022\n",
      "Epoch: 2885 Loss: 0.0778086706995964\n",
      "Epoch: 2886 Loss: 0.07773944735527039\n",
      "Epoch: 2887 Loss: 0.07769932597875595\n",
      "Epoch: 2888 Loss: 0.07764072716236115\n",
      "Epoch: 2889 Loss: 0.07759285718202591\n",
      "Epoch: 2890 Loss: 0.07751231640577316\n",
      "Epoch: 2891 Loss: 0.07747204601764679\n",
      "Epoch: 2892 Loss: 0.0774197056889534\n",
      "Epoch: 2893 Loss: 0.07738833874464035\n",
      "Epoch: 2894 Loss: 0.07733114808797836\n",
      "Epoch: 2895 Loss: 0.07728683948516846\n",
      "Epoch: 2896 Loss: 0.077255018055439\n",
      "Epoch: 2897 Loss: 0.07714913040399551\n",
      "Epoch: 2898 Loss: 0.077104352414608\n",
      "Epoch: 2899 Loss: 0.0770760253071785\n",
      "Epoch: 2900 Loss: 0.07699847966432571\n",
      "Epoch: 2901 Loss: 0.07695426046848297\n",
      "Epoch: 2902 Loss: 0.07690680027008057\n",
      "Epoch: 2903 Loss: 0.07682716846466064\n",
      "Epoch: 2904 Loss: 0.07682478427886963\n",
      "Epoch: 2905 Loss: 0.07673057913780212\n",
      "Epoch: 2906 Loss: 0.07669581472873688\n",
      "Epoch: 2907 Loss: 0.0766463428735733\n",
      "Epoch: 2908 Loss: 0.0766039565205574\n",
      "Epoch: 2909 Loss: 0.07651354372501373\n",
      "Epoch: 2910 Loss: 0.07646528631448746\n",
      "Epoch: 2911 Loss: 0.0764537826180458\n",
      "Epoch: 2912 Loss: 0.0764090046286583\n",
      "Epoch: 2913 Loss: 0.07632917165756226\n",
      "Epoch: 2914 Loss: 0.07631666213274002\n",
      "Epoch: 2915 Loss: 0.07621850818395615\n",
      "Epoch: 2916 Loss: 0.07619242370128632\n",
      "Epoch: 2917 Loss: 0.07610456645488739\n",
      "Epoch: 2918 Loss: 0.07605711370706558\n",
      "Epoch: 2919 Loss: 0.07600904256105423\n",
      "Epoch: 2920 Loss: 0.07593681663274765\n",
      "Epoch: 2921 Loss: 0.07588475942611694\n",
      "Epoch: 2922 Loss: 0.07589973509311676\n",
      "Epoch: 2923 Loss: 0.07580573111772537\n",
      "Epoch: 2924 Loss: 0.07574018836021423\n",
      "Epoch: 2925 Loss: 0.07570378482341766\n",
      "Epoch: 2926 Loss: 0.07563036680221558\n",
      "Epoch: 2927 Loss: 0.0755593329668045\n",
      "Epoch: 2928 Loss: 0.07553239166736603\n",
      "Epoch: 2929 Loss: 0.07546214759349823\n",
      "Epoch: 2930 Loss: 0.07540485262870789\n",
      "Epoch: 2931 Loss: 0.07539655268192291\n",
      "Epoch: 2932 Loss: 0.0752982497215271\n",
      "Epoch: 2933 Loss: 0.07528378814458847\n",
      "Epoch: 2934 Loss: 0.07519220560789108\n",
      "Epoch: 2935 Loss: 0.0751783475279808\n",
      "Epoch: 2936 Loss: 0.0750841349363327\n",
      "Epoch: 2937 Loss: 0.07505866885185242\n",
      "Epoch: 2938 Loss: 0.0749916285276413\n",
      "Epoch: 2939 Loss: 0.07494829595088959\n",
      "Epoch: 2940 Loss: 0.07490836083889008\n",
      "Epoch: 2941 Loss: 0.0748230516910553\n",
      "Epoch: 2942 Loss: 0.0747811570763588\n",
      "Epoch: 2943 Loss: 0.0747489407658577\n",
      "Epoch: 2944 Loss: 0.07465625554323196\n",
      "Epoch: 2945 Loss: 0.07461714744567871\n",
      "Epoch: 2946 Loss: 0.07458005100488663\n",
      "Epoch: 2947 Loss: 0.07450200617313385\n",
      "Epoch: 2948 Loss: 0.07447441667318344\n",
      "Epoch: 2949 Loss: 0.0744030624628067\n",
      "Epoch: 2950 Loss: 0.07436884939670563\n",
      "Epoch: 2951 Loss: 0.0743137076497078\n",
      "Epoch: 2952 Loss: 0.07429088652133942\n",
      "Epoch: 2953 Loss: 0.07418794929981232\n",
      "Epoch: 2954 Loss: 0.07415454834699631\n",
      "Epoch: 2955 Loss: 0.07407425343990326\n",
      "Epoch: 2956 Loss: 0.074066661298275\n",
      "Epoch: 2957 Loss: 0.07396478205919266\n",
      "Epoch: 2958 Loss: 0.07394804805517197\n",
      "Epoch: 2959 Loss: 0.07394195348024368\n",
      "Epoch: 2960 Loss: 0.07384902238845825\n",
      "Epoch: 2961 Loss: 0.07379162311553955\n",
      "Epoch: 2962 Loss: 0.07372858375310898\n",
      "Epoch: 2963 Loss: 0.07367579638957977\n",
      "Epoch: 2964 Loss: 0.07363250106573105\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2965 Loss: 0.07356490939855576\n",
      "Epoch: 2966 Loss: 0.07354529947042465\n",
      "Epoch: 2967 Loss: 0.07350584864616394\n",
      "Epoch: 2968 Loss: 0.07340973615646362\n",
      "Epoch: 2969 Loss: 0.07335770130157471\n",
      "Epoch: 2970 Loss: 0.07330495119094849\n",
      "Epoch: 2971 Loss: 0.07325077801942825\n",
      "Epoch: 2972 Loss: 0.07318269461393356\n",
      "Epoch: 2973 Loss: 0.07317445427179337\n",
      "Epoch: 2974 Loss: 0.07309599965810776\n",
      "Epoch: 2975 Loss: 0.07305600494146347\n",
      "Epoch: 2976 Loss: 0.07301383465528488\n",
      "Epoch: 2977 Loss: 0.07298077642917633\n",
      "Epoch: 2978 Loss: 0.07291411608457565\n",
      "Epoch: 2979 Loss: 0.0728539228439331\n",
      "Epoch: 2980 Loss: 0.07278301566839218\n",
      "Epoch: 2981 Loss: 0.07276879996061325\n",
      "Epoch: 2982 Loss: 0.07269257307052612\n",
      "Epoch: 2983 Loss: 0.07265006005764008\n",
      "Epoch: 2984 Loss: 0.07264236360788345\n",
      "Epoch: 2985 Loss: 0.07255671918392181\n",
      "Epoch: 2986 Loss: 0.07255373150110245\n",
      "Epoch: 2987 Loss: 0.07245618104934692\n",
      "Epoch: 2988 Loss: 0.07243408262729645\n",
      "Epoch: 2989 Loss: 0.07237093895673752\n",
      "Epoch: 2990 Loss: 0.0723646730184555\n",
      "Epoch: 2991 Loss: 0.07225627452135086\n",
      "Epoch: 2992 Loss: 0.07220830768346786\n",
      "Epoch: 2993 Loss: 0.07216838747262955\n",
      "Epoch: 2994 Loss: 0.07207702845335007\n",
      "Epoch: 2995 Loss: 0.07205222547054291\n",
      "Epoch: 2996 Loss: 0.07205958664417267\n",
      "Epoch: 2997 Loss: 0.07195122539997101\n",
      "Epoch: 2998 Loss: 0.07192162424325943\n",
      "Epoch: 2999 Loss: 0.07185433059930801\n",
      "Epoch: 3000 Loss: 0.07180285453796387\n",
      "Epoch: 3001 Loss: 0.07178300619125366\n",
      "Epoch: 3002 Loss: 0.07172336429357529\n",
      "Epoch: 3003 Loss: 0.0716632753610611\n",
      "Epoch: 3004 Loss: 0.07165617495775223\n",
      "Epoch: 3005 Loss: 0.07158973067998886\n",
      "Epoch: 3006 Loss: 0.07151852548122406\n",
      "Epoch: 3007 Loss: 0.07146938890218735\n",
      "Epoch: 3008 Loss: 0.0714326873421669\n",
      "Epoch: 3009 Loss: 0.07137566059827805\n",
      "Epoch: 3010 Loss: 0.07134402543306351\n",
      "Epoch: 3011 Loss: 0.0712815672159195\n",
      "Epoch: 3012 Loss: 0.0712413340806961\n",
      "Epoch: 3013 Loss: 0.07118362188339233\n",
      "Epoch: 3014 Loss: 0.07116571068763733\n",
      "Epoch: 3015 Loss: 0.07110536098480225\n",
      "Epoch: 3016 Loss: 0.07104211300611496\n",
      "Epoch: 3017 Loss: 0.07097779214382172\n",
      "Epoch: 3018 Loss: 0.07094357162714005\n",
      "Epoch: 3019 Loss: 0.07089189440011978\n",
      "Epoch: 3020 Loss: 0.07086333632469177\n",
      "Epoch: 3021 Loss: 0.07078468054533005\n",
      "Epoch: 3022 Loss: 0.07076036930084229\n",
      "Epoch: 3023 Loss: 0.07071911543607712\n",
      "Epoch: 3024 Loss: 0.07068714499473572\n",
      "Epoch: 3025 Loss: 0.0706150233745575\n",
      "Epoch: 3026 Loss: 0.070588618516922\n",
      "Epoch: 3027 Loss: 0.07052228599786758\n",
      "Epoch: 3028 Loss: 0.0705113336443901\n",
      "Epoch: 3029 Loss: 0.07042329758405685\n",
      "Epoch: 3030 Loss: 0.07037624716758728\n",
      "Epoch: 3031 Loss: 0.07036678493022919\n",
      "Epoch: 3032 Loss: 0.07030972093343735\n",
      "Epoch: 3033 Loss: 0.07023773342370987\n",
      "Epoch: 3034 Loss: 0.07018671184778214\n",
      "Epoch: 3035 Loss: 0.07016800343990326\n",
      "Epoch: 3036 Loss: 0.07011421024799347\n",
      "Epoch: 3037 Loss: 0.07002868503332138\n",
      "Epoch: 3038 Loss: 0.07000771909952164\n",
      "Epoch: 3039 Loss: 0.06995152682065964\n",
      "Epoch: 3040 Loss: 0.06993204355239868\n",
      "Epoch: 3041 Loss: 0.06987301260232925\n",
      "Epoch: 3042 Loss: 0.06982843577861786\n",
      "Epoch: 3043 Loss: 0.06976708024740219\n",
      "Epoch: 3044 Loss: 0.06972885876893997\n",
      "Epoch: 3045 Loss: 0.06969146430492401\n",
      "Epoch: 3046 Loss: 0.06965865939855576\n",
      "Epoch: 3047 Loss: 0.06959518045186996\n",
      "Epoch: 3048 Loss: 0.06954717636108398\n",
      "Epoch: 3049 Loss: 0.06947854906320572\n",
      "Epoch: 3050 Loss: 0.06946047395467758\n",
      "Epoch: 3051 Loss: 0.06939573585987091\n",
      "Epoch: 3052 Loss: 0.06934857368469238\n",
      "Epoch: 3053 Loss: 0.06932202726602554\n",
      "Epoch: 3054 Loss: 0.06925887614488602\n",
      "Epoch: 3055 Loss: 0.06923828274011612\n",
      "Epoch: 3056 Loss: 0.06923309713602066\n",
      "Epoch: 3057 Loss: 0.06911822408437729\n",
      "Epoch: 3058 Loss: 0.0690544992685318\n",
      "Epoch: 3059 Loss: 0.0690235048532486\n",
      "Epoch: 3060 Loss: 0.06899388134479523\n",
      "Epoch: 3061 Loss: 0.06893749535083771\n",
      "Epoch: 3062 Loss: 0.06894019991159439\n",
      "Epoch: 3063 Loss: 0.06884963810443878\n",
      "Epoch: 3064 Loss: 0.06883791834115982\n",
      "Epoch: 3065 Loss: 0.06878983974456787\n",
      "Epoch: 3066 Loss: 0.06871900707483292\n",
      "Epoch: 3067 Loss: 0.06865693628787994\n",
      "Epoch: 3068 Loss: 0.06863461434841156\n",
      "Epoch: 3069 Loss: 0.0685737356543541\n",
      "Epoch: 3070 Loss: 0.06854081153869629\n",
      "Epoch: 3071 Loss: 0.06849220395088196\n",
      "Epoch: 3072 Loss: 0.06844978779554367\n",
      "Epoch: 3073 Loss: 0.06841558963060379\n",
      "Epoch: 3074 Loss: 0.06835474073886871\n",
      "Epoch: 3075 Loss: 0.0683046355843544\n",
      "Epoch: 3076 Loss: 0.0682598277926445\n",
      "Epoch: 3077 Loss: 0.06821176409721375\n",
      "Epoch: 3078 Loss: 0.06820105016231537\n",
      "Epoch: 3079 Loss: 0.0681571289896965\n",
      "Epoch: 3080 Loss: 0.06809846311807632\n",
      "Epoch: 3081 Loss: 0.06802080571651459\n",
      "Epoch: 3082 Loss: 0.06799289584159851\n",
      "Epoch: 3083 Loss: 0.06797339767217636\n",
      "Epoch: 3084 Loss: 0.06793195009231567\n",
      "Epoch: 3085 Loss: 0.06783486157655716\n",
      "Epoch: 3086 Loss: 0.06782965362071991\n",
      "Epoch: 3087 Loss: 0.06776013970375061\n",
      "Epoch: 3088 Loss: 0.06776043772697449\n",
      "Epoch: 3089 Loss: 0.06768231093883514\n",
      "Epoch: 3090 Loss: 0.06766600161790848\n",
      "Epoch: 3091 Loss: 0.06759533286094666\n",
      "Epoch: 3092 Loss: 0.0675688311457634\n",
      "Epoch: 3093 Loss: 0.06748013943433762\n",
      "Epoch: 3094 Loss: 0.06744533032178879\n",
      "Epoch: 3095 Loss: 0.06739960610866547\n",
      "Epoch: 3096 Loss: 0.0673910602927208\n",
      "Epoch: 3097 Loss: 0.06732194870710373\n",
      "Epoch: 3098 Loss: 0.06729334592819214\n",
      "Epoch: 3099 Loss: 0.06725151091814041\n",
      "Epoch: 3100 Loss: 0.06721516698598862\n",
      "Epoch: 3101 Loss: 0.0671413391828537\n",
      "Epoch: 3102 Loss: 0.06712290644645691\n",
      "Epoch: 3103 Loss: 0.06709646433591843\n",
      "Epoch: 3104 Loss: 0.06703559309244156\n",
      "Epoch: 3105 Loss: 0.0669974610209465\n",
      "Epoch: 3106 Loss: 0.06695789843797684\n",
      "Epoch: 3107 Loss: 0.06686442345380783\n",
      "Epoch: 3108 Loss: 0.06684859097003937\n",
      "Epoch: 3109 Loss: 0.06678974628448486\n",
      "Epoch: 3110 Loss: 0.06678077578544617\n",
      "Epoch: 3111 Loss: 0.06670276820659637\n",
      "Epoch: 3112 Loss: 0.06668179482221603\n",
      "Epoch: 3113 Loss: 0.0666467696428299\n",
      "Epoch: 3114 Loss: 0.06659658253192902\n",
      "Epoch: 3115 Loss: 0.06655942648649216\n",
      "Epoch: 3116 Loss: 0.06651114672422409\n",
      "Epoch: 3117 Loss: 0.06642478704452515\n",
      "Epoch: 3118 Loss: 0.06639829277992249\n",
      "Epoch: 3119 Loss: 0.06635875999927521\n",
      "Epoch: 3120 Loss: 0.06637594103813171\n",
      "Epoch: 3121 Loss: 0.06629357486963272\n",
      "Epoch: 3122 Loss: 0.06625597178936005\n",
      "Epoch: 3123 Loss: 0.0661972239613533\n",
      "Epoch: 3124 Loss: 0.06616930663585663\n",
      "Epoch: 3125 Loss: 0.06613051146268845\n",
      "Epoch: 3126 Loss: 0.06606949120759964\n",
      "Epoch: 3127 Loss: 0.06602620333433151\n",
      "Epoch: 3128 Loss: 0.06597311049699783\n",
      "Epoch: 3129 Loss: 0.0659351497888565\n",
      "Epoch: 3130 Loss: 0.06590007245540619\n",
      "Epoch: 3131 Loss: 0.06584278494119644\n",
      "Epoch: 3132 Loss: 0.06580525636672974\n",
      "Epoch: 3133 Loss: 0.0657561719417572\n",
      "Epoch: 3134 Loss: 0.06575330346822739\n",
      "Epoch: 3135 Loss: 0.06566336750984192\n",
      "Epoch: 3136 Loss: 0.0656551793217659\n",
      "Epoch: 3137 Loss: 0.06559429317712784\n",
      "Epoch: 3138 Loss: 0.06558813899755478\n",
      "Epoch: 3139 Loss: 0.06551221758127213\n",
      "Epoch: 3140 Loss: 0.06551975756883621\n",
      "Epoch: 3141 Loss: 0.06541385501623154\n",
      "Epoch: 3142 Loss: 0.06539083272218704\n",
      "Epoch: 3143 Loss: 0.06532826274633408\n",
      "Epoch: 3144 Loss: 0.06529819965362549\n",
      "Epoch: 3145 Loss: 0.06526989489793777\n",
      "Epoch: 3146 Loss: 0.06523237377405167\n",
      "Epoch: 3147 Loss: 0.06515205651521683\n",
      "Epoch: 3148 Loss: 0.06513655185699463\n",
      "Epoch: 3149 Loss: 0.0650710016489029\n",
      "Epoch: 3150 Loss: 0.06506512314081192\n",
      "Epoch: 3151 Loss: 0.06504372507333755\n",
      "Epoch: 3152 Loss: 0.06495758891105652\n",
      "Epoch: 3153 Loss: 0.0649050921201706\n",
      "Epoch: 3154 Loss: 0.06487539410591125\n",
      "Epoch: 3155 Loss: 0.0648222267627716\n",
      "Epoch: 3156 Loss: 0.06481432914733887\n",
      "Epoch: 3157 Loss: 0.06473135948181152\n",
      "Epoch: 3158 Loss: 0.06474234163761139\n",
      "Epoch: 3159 Loss: 0.06464312225580215\n",
      "Epoch: 3160 Loss: 0.06464138627052307\n",
      "Epoch: 3161 Loss: 0.06458873301744461\n",
      "Epoch: 3162 Loss: 0.06453760713338852\n",
      "Epoch: 3163 Loss: 0.06448323279619217\n",
      "Epoch: 3164 Loss: 0.06446753442287445\n",
      "Epoch: 3165 Loss: 0.06441806256771088\n",
      "Epoch: 3166 Loss: 0.06439555436372757\n",
      "Epoch: 3167 Loss: 0.06430815905332565\n",
      "Epoch: 3168 Loss: 0.06428030133247375\n",
      "Epoch: 3169 Loss: 0.06424832344055176\n",
      "Epoch: 3170 Loss: 0.06421305239200592\n",
      "Epoch: 3171 Loss: 0.06419144570827484\n",
      "Epoch: 3172 Loss: 0.06414628028869629\n",
      "Epoch: 3173 Loss: 0.06407809257507324\n",
      "Epoch: 3174 Loss: 0.06403935700654984\n",
      "Epoch: 3175 Loss: 0.06399288028478622\n",
      "Epoch: 3176 Loss: 0.06395760178565979\n",
      "Epoch: 3177 Loss: 0.06391431391239166\n",
      "Epoch: 3178 Loss: 0.06386595964431763\n",
      "Epoch: 3179 Loss: 0.06385219097137451\n",
      "Epoch: 3180 Loss: 0.0638035461306572\n",
      "Epoch: 3181 Loss: 0.06374334543943405\n",
      "Epoch: 3182 Loss: 0.06373567879199982\n",
      "Epoch: 3183 Loss: 0.06369476020336151\n",
      "Epoch: 3184 Loss: 0.06364139169454575\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 3185 Loss: 0.06357269734144211\n",
      "Epoch: 3186 Loss: 0.06353369355201721\n",
      "Epoch: 3187 Loss: 0.06349340081214905\n",
      "Epoch: 3188 Loss: 0.06345989555120468\n",
      "Epoch: 3189 Loss: 0.06341288238763809\n",
      "Epoch: 3190 Loss: 0.0634339302778244\n",
      "Epoch: 3191 Loss: 0.0633382797241211\n",
      "Epoch: 3192 Loss: 0.06330052018165588\n",
      "Epoch: 3193 Loss: 0.06323433667421341\n",
      "Epoch: 3194 Loss: 0.06322506815195084\n",
      "Epoch: 3195 Loss: 0.06316656619310379\n",
      "Epoch: 3196 Loss: 0.06319136172533035\n",
      "Epoch: 3197 Loss: 0.06311500072479248\n",
      "Epoch: 3198 Loss: 0.0630607008934021\n",
      "Epoch: 3199 Loss: 0.06301185488700867\n",
      "Epoch: 3200 Loss: 0.06297914683818817\n",
      "Epoch: 3201 Loss: 0.06294315308332443\n",
      "Epoch: 3202 Loss: 0.06290165334939957\n",
      "Epoch: 3203 Loss: 0.06283605098724365\n",
      "Epoch: 3204 Loss: 0.06281360238790512\n",
      "Epoch: 3205 Loss: 0.06276153773069382\n",
      "Epoch: 3206 Loss: 0.06276555359363556\n",
      "Epoch: 3207 Loss: 0.06268615275621414\n",
      "Epoch: 3208 Loss: 0.06269682198762894\n",
      "Epoch: 3209 Loss: 0.06260818988084793\n",
      "Epoch: 3210 Loss: 0.0625920370221138\n",
      "Epoch: 3211 Loss: 0.06252133101224899\n",
      "Epoch: 3212 Loss: 0.06251241266727448\n",
      "Epoch: 3213 Loss: 0.062432464212179184\n",
      "Epoch: 3214 Loss: 0.062417007982730865\n",
      "Epoch: 3215 Loss: 0.062371835112571716\n",
      "Epoch: 3216 Loss: 0.06237493455410004\n",
      "Epoch: 3217 Loss: 0.062307000160217285\n",
      "Epoch: 3218 Loss: 0.062254905700683594\n",
      "Epoch: 3219 Loss: 0.06219293549656868\n",
      "Epoch: 3220 Loss: 0.06220276281237602\n",
      "Epoch: 3221 Loss: 0.06213751807808876\n",
      "Epoch: 3222 Loss: 0.062124498188495636\n",
      "Epoch: 3223 Loss: 0.062042929232120514\n",
      "Epoch: 3224 Loss: 0.062008049339056015\n",
      "Epoch: 3225 Loss: 0.061986226588487625\n",
      "Epoch: 3226 Loss: 0.061970196664333344\n",
      "Epoch: 3227 Loss: 0.06191094219684601\n",
      "Epoch: 3228 Loss: 0.06186378002166748\n",
      "Epoch: 3229 Loss: 0.061831291764974594\n",
      "Epoch: 3230 Loss: 0.061785176396369934\n",
      "Epoch: 3231 Loss: 0.061731018126010895\n",
      "Epoch: 3232 Loss: 0.06171592324972153\n",
      "Epoch: 3233 Loss: 0.06166934221982956\n",
      "Epoch: 3234 Loss: 0.061621349304914474\n",
      "Epoch: 3235 Loss: 0.06158357858657837\n",
      "Epoch: 3236 Loss: 0.06157560274004936\n",
      "Epoch: 3237 Loss: 0.06150490790605545\n",
      "Epoch: 3238 Loss: 0.061489254236221313\n",
      "Epoch: 3239 Loss: 0.061431389302015305\n",
      "Epoch: 3240 Loss: 0.06141085922718048\n",
      "Epoch: 3241 Loss: 0.061372917145490646\n",
      "Epoch: 3242 Loss: 0.06131478399038315\n",
      "Epoch: 3243 Loss: 0.06126180291175842\n",
      "Epoch: 3244 Loss: 0.06125183403491974\n",
      "Epoch: 3245 Loss: 0.06118201091885567\n",
      "Epoch: 3246 Loss: 0.06116895750164986\n",
      "Epoch: 3247 Loss: 0.06111764907836914\n",
      "Epoch: 3248 Loss: 0.06108127161860466\n",
      "Epoch: 3249 Loss: 0.061033111065626144\n",
      "Epoch: 3250 Loss: 0.061065495014190674\n",
      "Epoch: 3251 Loss: 0.06098974868655205\n",
      "Epoch: 3252 Loss: 0.06092328950762749\n",
      "Epoch: 3253 Loss: 0.06089543551206589\n",
      "Epoch: 3254 Loss: 0.06087848171591759\n",
      "Epoch: 3255 Loss: 0.06083128601312637\n",
      "Epoch: 3256 Loss: 0.06079192832112312\n",
      "Epoch: 3257 Loss: 0.06074611097574234\n",
      "Epoch: 3258 Loss: 0.06072186678647995\n",
      "Epoch: 3259 Loss: 0.060666102916002274\n",
      "Epoch: 3260 Loss: 0.06061571463942528\n",
      "Epoch: 3261 Loss: 0.06057179719209671\n",
      "Epoch: 3262 Loss: 0.06056761369109154\n",
      "Epoch: 3263 Loss: 0.06048830971121788\n",
      "Epoch: 3264 Loss: 0.06046150252223015\n",
      "Epoch: 3265 Loss: 0.06044334918260574\n",
      "Epoch: 3266 Loss: 0.060439012944698334\n",
      "Epoch: 3267 Loss: 0.06039276719093323\n",
      "Epoch: 3268 Loss: 0.06035042926669121\n",
      "Epoch: 3269 Loss: 0.06028008088469505\n",
      "Epoch: 3270 Loss: 0.06026878580451012\n",
      "Epoch: 3271 Loss: 0.060193318873643875\n",
      "Epoch: 3272 Loss: 0.06017887964844704\n",
      "Epoch: 3273 Loss: 0.06013591215014458\n",
      "Epoch: 3274 Loss: 0.0601126030087471\n",
      "Epoch: 3275 Loss: 0.06004844605922699\n",
      "Epoch: 3276 Loss: 0.060030218213796616\n",
      "Epoch: 3277 Loss: 0.05996986851096153\n",
      "Epoch: 3278 Loss: 0.05997911095619202\n",
      "Epoch: 3279 Loss: 0.05990936607122421\n",
      "Epoch: 3280 Loss: 0.059879619628190994\n",
      "Epoch: 3281 Loss: 0.05982145294547081\n",
      "Epoch: 3282 Loss: 0.05981264263391495\n",
      "Epoch: 3283 Loss: 0.05977141857147217\n",
      "Epoch: 3284 Loss: 0.05972179397940636\n",
      "Epoch: 3285 Loss: 0.05968870222568512\n",
      "Epoch: 3286 Loss: 0.05964513123035431\n",
      "Epoch: 3287 Loss: 0.059618134051561356\n",
      "Epoch: 3288 Loss: 0.05960185080766678\n",
      "Epoch: 3289 Loss: 0.05955643579363823\n",
      "Epoch: 3290 Loss: 0.059519216418266296\n",
      "Epoch: 3291 Loss: 0.05947588011622429\n",
      "Epoch: 3292 Loss: 0.059430722147226334\n",
      "Epoch: 3293 Loss: 0.05938097462058067\n",
      "Epoch: 3294 Loss: 0.0593789666891098\n",
      "Epoch: 3295 Loss: 0.05930008739233017\n",
      "Epoch: 3296 Loss: 0.05928564444184303\n",
      "Epoch: 3297 Loss: 0.05923639237880707\n",
      "Epoch: 3298 Loss: 0.05921991169452667\n",
      "Epoch: 3299 Loss: 0.05919774994254112\n",
      "Epoch: 3300 Loss: 0.05914466455578804\n",
      "Epoch: 3301 Loss: 0.059088338166475296\n",
      "Epoch: 3302 Loss: 0.05904748663306236\n",
      "Epoch: 3303 Loss: 0.05903233587741852\n",
      "Epoch: 3304 Loss: 0.05898714438080788\n",
      "Epoch: 3305 Loss: 0.05893644317984581\n",
      "Epoch: 3306 Loss: 0.05894331634044647\n",
      "Epoch: 3307 Loss: 0.05886730179190636\n",
      "Epoch: 3308 Loss: 0.05886242911219597\n",
      "Epoch: 3309 Loss: 0.05881306901574135\n",
      "Epoch: 3310 Loss: 0.058795180171728134\n",
      "Epoch: 3311 Loss: 0.05873514339327812\n",
      "Epoch: 3312 Loss: 0.05868292599916458\n",
      "Epoch: 3313 Loss: 0.058664098381996155\n",
      "Epoch: 3314 Loss: 0.05864277854561806\n",
      "Epoch: 3315 Loss: 0.058587316423654556\n",
      "Epoch: 3316 Loss: 0.05856044590473175\n",
      "Epoch: 3317 Loss: 0.058509550988674164\n",
      "Epoch: 3318 Loss: 0.05850648880004883\n",
      "Epoch: 3319 Loss: 0.05847584456205368\n",
      "Epoch: 3320 Loss: 0.05844089388847351\n",
      "Epoch: 3321 Loss: 0.058353930711746216\n",
      "Epoch: 3322 Loss: 0.05834655836224556\n",
      "Epoch: 3323 Loss: 0.058281902223825455\n",
      "Epoch: 3324 Loss: 0.05826737359166145\n",
      "Epoch: 3325 Loss: 0.058231331408023834\n",
      "Epoch: 3326 Loss: 0.058214861899614334\n",
      "Epoch: 3327 Loss: 0.058153245598077774\n",
      "Epoch: 3328 Loss: 0.058135442435741425\n",
      "Epoch: 3329 Loss: 0.058069270104169846\n",
      "Epoch: 3330 Loss: 0.05804704874753952\n",
      "Epoch: 3331 Loss: 0.05800910294055939\n",
      "Epoch: 3332 Loss: 0.058013252913951874\n",
      "Epoch: 3333 Loss: 0.05793751776218414\n",
      "Epoch: 3334 Loss: 0.05792651325464249\n",
      "Epoch: 3335 Loss: 0.0578775592148304\n",
      "Epoch: 3336 Loss: 0.05784284323453903\n",
      "Epoch: 3337 Loss: 0.057802360504865646\n",
      "Epoch: 3338 Loss: 0.057763151824474335\n",
      "Epoch: 3339 Loss: 0.057727307081222534\n",
      "Epoch: 3340 Loss: 0.05771728977560997\n",
      "Epoch: 3341 Loss: 0.057660482823848724\n",
      "Epoch: 3342 Loss: 0.057641927152872086\n",
      "Epoch: 3343 Loss: 0.057600971311330795\n",
      "Epoch: 3344 Loss: 0.0575639009475708\n",
      "Epoch: 3345 Loss: 0.057520341128110886\n",
      "Epoch: 3346 Loss: 0.05747240409255028\n",
      "Epoch: 3347 Loss: 0.057426728308200836\n",
      "Epoch: 3348 Loss: 0.0574183315038681\n",
      "Epoch: 3349 Loss: 0.05739198625087738\n",
      "Epoch: 3350 Loss: 0.05735669657588005\n",
      "Epoch: 3351 Loss: 0.0573163665831089\n",
      "Epoch: 3352 Loss: 0.0572766438126564\n",
      "Epoch: 3353 Loss: 0.057234812527894974\n",
      "Epoch: 3354 Loss: 0.057223379611968994\n",
      "Epoch: 3355 Loss: 0.05716249719262123\n",
      "Epoch: 3356 Loss: 0.05716244876384735\n",
      "Epoch: 3357 Loss: 0.057105645537376404\n",
      "Epoch: 3358 Loss: 0.057071421295404434\n",
      "Epoch: 3359 Loss: 0.05701873078942299\n",
      "Epoch: 3360 Loss: 0.057002920657396317\n",
      "Epoch: 3361 Loss: 0.05695484206080437\n",
      "Epoch: 3362 Loss: 0.05694461613893509\n",
      "Epoch: 3363 Loss: 0.05688447132706642\n",
      "Epoch: 3364 Loss: 0.05688656494021416\n",
      "Epoch: 3365 Loss: 0.05681869760155678\n",
      "Epoch: 3366 Loss: 0.05682683363556862\n",
      "Epoch: 3367 Loss: 0.05674222856760025\n",
      "Epoch: 3368 Loss: 0.056756168603897095\n",
      "Epoch: 3369 Loss: 0.05668303370475769\n",
      "Epoch: 3370 Loss: 0.056660912930965424\n",
      "Epoch: 3371 Loss: 0.05659537762403488\n",
      "Epoch: 3372 Loss: 0.05657651647925377\n",
      "Epoch: 3373 Loss: 0.056559767574071884\n",
      "Epoch: 3374 Loss: 0.05653161182999611\n",
      "Epoch: 3375 Loss: 0.05647260323166847\n",
      "Epoch: 3376 Loss: 0.05644860863685608\n",
      "Epoch: 3377 Loss: 0.05640244483947754\n",
      "Epoch: 3378 Loss: 0.05638465657830238\n",
      "Epoch: 3379 Loss: 0.056331392377614975\n",
      "Epoch: 3380 Loss: 0.05633570998907089\n",
      "Epoch: 3381 Loss: 0.05628124251961708\n",
      "Epoch: 3382 Loss: 0.05627455189824104\n",
      "Epoch: 3383 Loss: 0.056198883801698685\n",
      "Epoch: 3384 Loss: 0.05619456619024277\n",
      "Epoch: 3385 Loss: 0.05615656077861786\n",
      "Epoch: 3386 Loss: 0.056132908910512924\n",
      "Epoch: 3387 Loss: 0.05605945363640785\n",
      "Epoch: 3388 Loss: 0.05606815963983536\n",
      "Epoch: 3389 Loss: 0.05599013343453407\n",
      "Epoch: 3390 Loss: 0.05598239600658417\n",
      "Epoch: 3391 Loss: 0.05592656880617142\n",
      "Epoch: 3392 Loss: 0.05591076985001564\n",
      "Epoch: 3393 Loss: 0.0558663047850132\n",
      "Epoch: 3394 Loss: 0.055833373218774796\n",
      "Epoch: 3395 Loss: 0.05579991266131401\n",
      "Epoch: 3396 Loss: 0.0557895265519619\n",
      "Epoch: 3397 Loss: 0.055749114602804184\n",
      "Epoch: 3398 Loss: 0.05570654943585396\n",
      "Epoch: 3399 Loss: 0.05563794821500778\n",
      "Epoch: 3400 Loss: 0.055648911744356155\n",
      "Epoch: 3401 Loss: 0.05559155344963074\n",
      "Epoch: 3402 Loss: 0.055591195821762085\n",
      "Epoch: 3403 Loss: 0.05553167685866356\n",
      "Epoch: 3404 Loss: 0.055515699088573456\n",
      "Epoch: 3405 Loss: 0.05548751726746559\n",
      "Epoch: 3406 Loss: 0.05543481931090355\n",
      "Epoch: 3407 Loss: 0.055390797555446625\n",
      "Epoch: 3408 Loss: 0.055396001785993576\n",
      "Epoch: 3409 Loss: 0.055343903601169586\n",
      "Epoch: 3410 Loss: 0.055346302688121796\n",
      "Epoch: 3411 Loss: 0.055263716727495193\n",
      "Epoch: 3412 Loss: 0.055231861770153046\n",
      "Epoch: 3413 Loss: 0.05520423501729965\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 3414 Loss: 0.05518385395407677\n",
      "Epoch: 3415 Loss: 0.05515581741929054\n",
      "Epoch: 3416 Loss: 0.05510089918971062\n",
      "Epoch: 3417 Loss: 0.05507148429751396\n",
      "Epoch: 3418 Loss: 0.055032409727573395\n",
      "Epoch: 3419 Loss: 0.055004190653562546\n",
      "Epoch: 3420 Loss: 0.05498422682285309\n",
      "Epoch: 3421 Loss: 0.05494361370801926\n",
      "Epoch: 3422 Loss: 0.05491350591182709\n",
      "Epoch: 3423 Loss: 0.05486806482076645\n",
      "Epoch: 3424 Loss: 0.054837729781866074\n",
      "Epoch: 3425 Loss: 0.05482647940516472\n",
      "Epoch: 3426 Loss: 0.05478033423423767\n",
      "Epoch: 3427 Loss: 0.05474850907921791\n",
      "Epoch: 3428 Loss: 0.05471814051270485\n",
      "Epoch: 3429 Loss: 0.054683517664670944\n",
      "Epoch: 3430 Loss: 0.05466316640377045\n",
      "Epoch: 3431 Loss: 0.054595593363046646\n",
      "Epoch: 3432 Loss: 0.054579559713602066\n",
      "Epoch: 3433 Loss: 0.054536182433366776\n",
      "Epoch: 3434 Loss: 0.05453020706772804\n",
      "Epoch: 3435 Loss: 0.05447082221508026\n",
      "Epoch: 3436 Loss: 0.05447249487042427\n",
      "Epoch: 3437 Loss: 0.05440276488661766\n",
      "Epoch: 3438 Loss: 0.05441807582974434\n",
      "Epoch: 3439 Loss: 0.05435502529144287\n",
      "Epoch: 3440 Loss: 0.05433781072497368\n",
      "Epoch: 3441 Loss: 0.05428415909409523\n",
      "Epoch: 3442 Loss: 0.054273687303066254\n",
      "Epoch: 3443 Loss: 0.05421901121735573\n",
      "Epoch: 3444 Loss: 0.054190900176763535\n",
      "Epoch: 3445 Loss: 0.05415476858615875\n",
      "Epoch: 3446 Loss: 0.054151181131601334\n",
      "Epoch: 3447 Loss: 0.05409334972500801\n",
      "Epoch: 3448 Loss: 0.05407136306166649\n",
      "Epoch: 3449 Loss: 0.05402374267578125\n",
      "Epoch: 3450 Loss: 0.05401148274540901\n",
      "Epoch: 3451 Loss: 0.053959865123033524\n",
      "Epoch: 3452 Loss: 0.05395049974322319\n",
      "Epoch: 3453 Loss: 0.053880106657743454\n",
      "Epoch: 3454 Loss: 0.05386859178543091\n",
      "Epoch: 3455 Loss: 0.05383700877428055\n",
      "Epoch: 3456 Loss: 0.05382174253463745\n",
      "Epoch: 3457 Loss: 0.053791843354701996\n",
      "Epoch: 3458 Loss: 0.053790099918842316\n",
      "Epoch: 3459 Loss: 0.05370165407657623\n",
      "Epoch: 3460 Loss: 0.05371124669909477\n",
      "Epoch: 3461 Loss: 0.05364339426159859\n",
      "Epoch: 3462 Loss: 0.05361243337392807\n",
      "Epoch: 3463 Loss: 0.053580984473228455\n",
      "Epoch: 3464 Loss: 0.05354846268892288\n",
      "Epoch: 3465 Loss: 0.0535048209130764\n",
      "Epoch: 3466 Loss: 0.05350146070122719\n",
      "Epoch: 3467 Loss: 0.05347049981355667\n",
      "Epoch: 3468 Loss: 0.053437311202287674\n",
      "Epoch: 3469 Loss: 0.05340546742081642\n",
      "Epoch: 3470 Loss: 0.0533791184425354\n",
      "Epoch: 3471 Loss: 0.0533120296895504\n",
      "Epoch: 3472 Loss: 0.05331496521830559\n",
      "Epoch: 3473 Loss: 0.05327082425355911\n",
      "Epoch: 3474 Loss: 0.05324846878647804\n",
      "Epoch: 3475 Loss: 0.05322248116135597\n",
      "Epoch: 3476 Loss: 0.05317804217338562\n",
      "Epoch: 3477 Loss: 0.05313737690448761\n",
      "Epoch: 3478 Loss: 0.053133171051740646\n",
      "Epoch: 3479 Loss: 0.05308035388588905\n",
      "Epoch: 3480 Loss: 0.05305343493819237\n",
      "Epoch: 3481 Loss: 0.05303172022104263\n",
      "Epoch: 3482 Loss: 0.052994344383478165\n",
      "Epoch: 3483 Loss: 0.05297558382153511\n",
      "Epoch: 3484 Loss: 0.052944764494895935\n",
      "Epoch: 3485 Loss: 0.05288867652416229\n",
      "Epoch: 3486 Loss: 0.052843596786260605\n",
      "Epoch: 3487 Loss: 0.05282924324274063\n",
      "Epoch: 3488 Loss: 0.05282086879014969\n",
      "Epoch: 3489 Loss: 0.052760787308216095\n",
      "Epoch: 3490 Loss: 0.0527583546936512\n",
      "Epoch: 3491 Loss: 0.05270300433039665\n",
      "Epoch: 3492 Loss: 0.05267634615302086\n",
      "Epoch: 3493 Loss: 0.052643612027168274\n",
      "Epoch: 3494 Loss: 0.0526510551571846\n",
      "Epoch: 3495 Loss: 0.052585285156965256\n",
      "Epoch: 3496 Loss: 0.05254911258816719\n",
      "Epoch: 3497 Loss: 0.05251482501626015\n",
      "Epoch: 3498 Loss: 0.05250653997063637\n",
      "Epoch: 3499 Loss: 0.052463795989751816\n",
      "Epoch: 3500 Loss: 0.05244041234254837\n",
      "Epoch: 3501 Loss: 0.052418433129787445\n",
      "Epoch: 3502 Loss: 0.05237872526049614\n",
      "Epoch: 3503 Loss: 0.05234377831220627\n",
      "Epoch: 3504 Loss: 0.05234122276306152\n",
      "Epoch: 3505 Loss: 0.052277691662311554\n",
      "Epoch: 3506 Loss: 0.052280619740486145\n",
      "Epoch: 3507 Loss: 0.0522184781730175\n",
      "Epoch: 3508 Loss: 0.05219171196222305\n",
      "Epoch: 3509 Loss: 0.05215653404593468\n",
      "Epoch: 3510 Loss: 0.052134282886981964\n",
      "Epoch: 3511 Loss: 0.052102115005254745\n",
      "Epoch: 3512 Loss: 0.052074894309043884\n",
      "Epoch: 3513 Loss: 0.05201118439435959\n",
      "Epoch: 3514 Loss: 0.05202001333236694\n",
      "Epoch: 3515 Loss: 0.05196477472782135\n",
      "Epoch: 3516 Loss: 0.051938001066446304\n",
      "Epoch: 3517 Loss: 0.05189696326851845\n",
      "Epoch: 3518 Loss: 0.05189301446080208\n",
      "Epoch: 3519 Loss: 0.05185754597187042\n",
      "Epoch: 3520 Loss: 0.0518234558403492\n",
      "Epoch: 3521 Loss: 0.051791612058877945\n",
      "Epoch: 3522 Loss: 0.05178077891469002\n",
      "Epoch: 3523 Loss: 0.05172869190573692\n",
      "Epoch: 3524 Loss: 0.051704343408346176\n",
      "Epoch: 3525 Loss: 0.05169496312737465\n",
      "Epoch: 3526 Loss: 0.051661618053913116\n",
      "Epoch: 3527 Loss: 0.05163272097706795\n",
      "Epoch: 3528 Loss: 0.051595788449048996\n",
      "Epoch: 3529 Loss: 0.05154598876833916\n",
      "Epoch: 3530 Loss: 0.05151526257395744\n",
      "Epoch: 3531 Loss: 0.051487747579813004\n",
      "Epoch: 3532 Loss: 0.05147644504904747\n",
      "Epoch: 3533 Loss: 0.051423490047454834\n",
      "Epoch: 3534 Loss: 0.05141093209385872\n",
      "Epoch: 3535 Loss: 0.05135224014520645\n",
      "Epoch: 3536 Loss: 0.05137848109006882\n",
      "Epoch: 3537 Loss: 0.05131442844867706\n",
      "Epoch: 3538 Loss: 0.051283933222293854\n",
      "Epoch: 3539 Loss: 0.05125226825475693\n",
      "Epoch: 3540 Loss: 0.051230914890766144\n",
      "Epoch: 3541 Loss: 0.051194265484809875\n",
      "Epoch: 3542 Loss: 0.05121161416172981\n",
      "Epoch: 3543 Loss: 0.05113750323653221\n",
      "Epoch: 3544 Loss: 0.05112072825431824\n",
      "Epoch: 3545 Loss: 0.05106585845351219\n",
      "Epoch: 3546 Loss: 0.051057107746601105\n",
      "Epoch: 3547 Loss: 0.05100481957197189\n",
      "Epoch: 3548 Loss: 0.05100404843688011\n",
      "Epoch: 3549 Loss: 0.05095357447862625\n",
      "Epoch: 3550 Loss: 0.05095967650413513\n",
      "Epoch: 3551 Loss: 0.050915997475385666\n",
      "Epoch: 3552 Loss: 0.050891075283288956\n",
      "Epoch: 3553 Loss: 0.05082890763878822\n",
      "Epoch: 3554 Loss: 0.05080520734190941\n",
      "Epoch: 3555 Loss: 0.05077783390879631\n",
      "Epoch: 3556 Loss: 0.05076297000050545\n",
      "Epoch: 3557 Loss: 0.050734780728816986\n",
      "Epoch: 3558 Loss: 0.050696928054094315\n",
      "Epoch: 3559 Loss: 0.050684016197919846\n",
      "Epoch: 3560 Loss: 0.05065213888883591\n",
      "Epoch: 3561 Loss: 0.0506155863404274\n",
      "Epoch: 3562 Loss: 0.05057939514517784\n",
      "Epoch: 3563 Loss: 0.050567273050546646\n",
      "Epoch: 3564 Loss: 0.050530582666397095\n",
      "Epoch: 3565 Loss: 0.05050809308886528\n",
      "Epoch: 3566 Loss: 0.05045768618583679\n",
      "Epoch: 3567 Loss: 0.05043056607246399\n",
      "Epoch: 3568 Loss: 0.05040130019187927\n",
      "Epoch: 3569 Loss: 0.050372879952192307\n",
      "Epoch: 3570 Loss: 0.050347406417131424\n",
      "Epoch: 3571 Loss: 0.050327103585004807\n",
      "Epoch: 3572 Loss: 0.050312209874391556\n",
      "Epoch: 3573 Loss: 0.05026612803339958\n",
      "Epoch: 3574 Loss: 0.05025063082575798\n",
      "Epoch: 3575 Loss: 0.05021345242857933\n",
      "Epoch: 3576 Loss: 0.05017717182636261\n",
      "Epoch: 3577 Loss: 0.050139445811510086\n",
      "Epoch: 3578 Loss: 0.050131168216466904\n",
      "Epoch: 3579 Loss: 0.05007652938365936\n",
      "Epoch: 3580 Loss: 0.05007331818342209\n",
      "Epoch: 3581 Loss: 0.050024908035993576\n",
      "Epoch: 3582 Loss: 0.05000894516706467\n",
      "Epoch: 3583 Loss: 0.049971237778663635\n",
      "Epoch: 3584 Loss: 0.04995604231953621\n",
      "Epoch: 3585 Loss: 0.0499209389090538\n",
      "Epoch: 3586 Loss: 0.04989796504378319\n",
      "Epoch: 3587 Loss: 0.049868904054164886\n",
      "Epoch: 3588 Loss: 0.04984298348426819\n",
      "Epoch: 3589 Loss: 0.049796488136053085\n",
      "Epoch: 3590 Loss: 0.04978194832801819\n",
      "Epoch: 3591 Loss: 0.0497467927634716\n",
      "Epoch: 3592 Loss: 0.04972881078720093\n",
      "Epoch: 3593 Loss: 0.04968840256333351\n",
      "Epoch: 3594 Loss: 0.04968957602977753\n",
      "Epoch: 3595 Loss: 0.049639977514743805\n",
      "Epoch: 3596 Loss: 0.04962172359228134\n",
      "Epoch: 3597 Loss: 0.04957414045929909\n",
      "Epoch: 3598 Loss: 0.04955971613526344\n",
      "Epoch: 3599 Loss: 0.04952079802751541\n",
      "Epoch: 3600 Loss: 0.049505338072776794\n",
      "Epoch: 3601 Loss: 0.04947277531027794\n",
      "Epoch: 3602 Loss: 0.04946595057845116\n",
      "Epoch: 3603 Loss: 0.049410417675971985\n",
      "Epoch: 3604 Loss: 0.049404650926589966\n",
      "Epoch: 3605 Loss: 0.04937617853283882\n",
      "Epoch: 3606 Loss: 0.04934030398726463\n",
      "Epoch: 3607 Loss: 0.04930409789085388\n",
      "Epoch: 3608 Loss: 0.04930046200752258\n",
      "Epoch: 3609 Loss: 0.04923976585268974\n",
      "Epoch: 3610 Loss: 0.04921631142497063\n",
      "Epoch: 3611 Loss: 0.049187127500772476\n",
      "Epoch: 3612 Loss: 0.04916967824101448\n",
      "Epoch: 3613 Loss: 0.04912184551358223\n",
      "Epoch: 3614 Loss: 0.049103111028671265\n",
      "Epoch: 3615 Loss: 0.04907674714922905\n",
      "Epoch: 3616 Loss: 0.049055419862270355\n",
      "Epoch: 3617 Loss: 0.04903252422809601\n",
      "Epoch: 3618 Loss: 0.049030158668756485\n",
      "Epoch: 3619 Loss: 0.048985205590724945\n",
      "Epoch: 3620 Loss: 0.048955366015434265\n",
      "Epoch: 3621 Loss: 0.04890984296798706\n",
      "Epoch: 3622 Loss: 0.048905689269304276\n",
      "Epoch: 3623 Loss: 0.04886305332183838\n",
      "Epoch: 3624 Loss: 0.048835232853889465\n",
      "Epoch: 3625 Loss: 0.04881484434008598\n",
      "Epoch: 3626 Loss: 0.04879707098007202\n",
      "Epoch: 3627 Loss: 0.04874207824468613\n",
      "Epoch: 3628 Loss: 0.04872548207640648\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 3629 Loss: 0.04869323596358299\n",
      "Epoch: 3630 Loss: 0.04866374284029007\n",
      "Epoch: 3631 Loss: 0.04861191287636757\n",
      "Epoch: 3632 Loss: 0.0486217699944973\n",
      "Epoch: 3633 Loss: 0.04858650639653206\n",
      "Epoch: 3634 Loss: 0.048589807003736496\n",
      "Epoch: 3635 Loss: 0.0485529787838459\n",
      "Epoch: 3636 Loss: 0.0485168918967247\n",
      "Epoch: 3637 Loss: 0.04846867918968201\n",
      "Epoch: 3638 Loss: 0.04847374185919762\n",
      "Epoch: 3639 Loss: 0.04844004288315773\n",
      "Epoch: 3640 Loss: 0.04842530936002731\n",
      "Epoch: 3641 Loss: 0.048388298600912094\n",
      "Epoch: 3642 Loss: 0.0483526811003685\n",
      "Epoch: 3643 Loss: 0.04832077398896217\n",
      "Epoch: 3644 Loss: 0.04829636588692665\n",
      "Epoch: 3645 Loss: 0.048253320157527924\n",
      "Epoch: 3646 Loss: 0.048239342868328094\n",
      "Epoch: 3647 Loss: 0.048190582543611526\n",
      "Epoch: 3648 Loss: 0.048195064067840576\n",
      "Epoch: 3649 Loss: 0.048155881464481354\n",
      "Epoch: 3650 Loss: 0.0481334887444973\n",
      "Epoch: 3651 Loss: 0.048104576766490936\n",
      "Epoch: 3652 Loss: 0.04809153452515602\n",
      "Epoch: 3653 Loss: 0.04806551709771156\n",
      "Epoch: 3654 Loss: 0.04802782088518143\n",
      "Epoch: 3655 Loss: 0.04801391065120697\n",
      "Epoch: 3656 Loss: 0.04797978326678276\n",
      "Epoch: 3657 Loss: 0.04797104373574257\n",
      "Epoch: 3658 Loss: 0.04794875159859657\n",
      "Epoch: 3659 Loss: 0.04786994680762291\n",
      "Epoch: 3660 Loss: 0.04784686490893364\n",
      "Epoch: 3661 Loss: 0.04782173037528992\n",
      "Epoch: 3662 Loss: 0.04781850799918175\n",
      "Epoch: 3663 Loss: 0.047786369919776917\n",
      "Epoch: 3664 Loss: 0.04777321591973305\n",
      "Epoch: 3665 Loss: 0.047733720391988754\n",
      "Epoch: 3666 Loss: 0.04771918058395386\n",
      "Epoch: 3667 Loss: 0.047689057886600494\n",
      "Epoch: 3668 Loss: 0.04766952991485596\n",
      "Epoch: 3669 Loss: 0.0476405993103981\n",
      "Epoch: 3670 Loss: 0.04760429263114929\n",
      "Epoch: 3671 Loss: 0.04756413400173187\n",
      "Epoch: 3672 Loss: 0.047551434487104416\n",
      "Epoch: 3673 Loss: 0.047530192881822586\n",
      "Epoch: 3674 Loss: 0.04748920351266861\n",
      "Epoch: 3675 Loss: 0.04746335372328758\n",
      "Epoch: 3676 Loss: 0.04744136705994606\n",
      "Epoch: 3677 Loss: 0.04741412773728371\n",
      "Epoch: 3678 Loss: 0.04740610346198082\n",
      "Epoch: 3679 Loss: 0.047370076179504395\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m"
     ]
    }
   ],
   "source": [
    "# Train the model in mini-batches of BATCH_SIZE examples\n",
    "BATCH_SIZE = 128\n",
    "for epoch in range(10000):\n",
    "    for start in range(0, len(trX), BATCH_SIZE):\n",
    "        end = start + BATCH_SIZE\n",
    "        batchX = trX[start:end]\n",
    "        batchY = trY[start:end]\n",
    "\n",
    "        y_pred = model(batchX)\n",
    "        loss = loss_fn(y_pred, batchY)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Find loss on the full training set (no gradient tracking needed)\n",
    "    with torch.no_grad():\n",
    "        loss = loss_fn(model(trX), trY).item()\n",
    "    print('Epoch:', epoch, 'Loss:', loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we use the trained model to play FizzBuzz on the numbers 1 to 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Encode 1..100 in binary and run the trained model on them\n",
    "testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])\n",
    "with torch.no_grad():\n",
    "    testY = model(testX)\n",
    "# The predicted class for each number is the index of its largest output\n",
    "predictions = zip(range(1, 101), testY.max(1)[1].tolist())\n",
    "\n",
    "print([fizz_buzz_decode(i, x) for (i, x) in predictions])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Compare predicted labels with the true FizzBuzz labels for 1..100\n",
    "correct = testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1, 101)])\n",
    "print(np.sum(correct))  # number of correct predictions out of 100\n",
    "correct"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
